Coverage for lice2/helpers.py: 100%

184 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-13 11:11 +0000

1"""Helper functions for LICE2.""" 

2 

3import getpass 

4import json 

5import os 

6import re 

7import subprocess 

8import sys 

9from contextlib import closing 

10from datetime import datetime 

11from importlib import resources 

12from io import StringIO 

13from pathlib import Path 

14from types import SimpleNamespace 

15from typing import Union 

16 

17import typer 

18from rich.console import Console 

19from rich.table import Table 

20from rich.text import Text 

21 

22from lice2.config import settings 

23from lice2.constants import LANG_CMT, LANGS, LICENSES 

24 

25 

def clean_path(p: str) -> str:
    """Clean a path.

    Expand environment variables and the user's home directory (``~``), then
    resolve the result to an absolute path.

    Environment variables are expanded first so that a variable whose value
    begins with ``~`` is also expanded to the home directory.

    Args:
        p: The path to clean.

    Returns:
        The absolute, resolved path as a string.
    """
    expanded = Path(os.path.expandvars(p)).expanduser()
    return str(expanded.resolve())

33 

34 

def guess_organization() -> str:
    """Guess the organization (copyright holder) name.

    First, try to get it from the settings file. If that is blank, use the
    name reported by ``git config --get user.name``. If that can't be found
    (git is not installed, the command fails, or the configured name is
    empty), fall back to the current login name from ``getpass.getuser()``.

    Returns:
        The best-guess organization name.
    """
    if settings.organization:
        return settings.organization

    try:
        stdout = subprocess.check_output(  # noqa: S603
            "git config --get user.name".split()
        )
    except (OSError, subprocess.CalledProcessError):
        # OSError (e.g. FileNotFoundError) means the git executable could not
        # be run at all; CalledProcessError means git exited non-zero (for
        # example, no user.name is configured).
        return getpass.getuser()

    # An explicitly empty user.name is no better than none at all.
    return stdout.strip().decode("UTF-8") or getpass.getuser()

50 

51 

def get_context(args: SimpleNamespace) -> dict[str, str]:
    """Build the template substitution context from the parsed CLI arguments.

    Only the ``year``, ``organization`` and ``project`` attributes are read;
    any other attributes on ``args`` are ignored.
    """
    wanted = ("year", "organization", "project")
    return {name: getattr(args, name) for name in wanted}

59 

60 

def get_lang(args: SimpleNamespace) -> str:
    """Return the requested language, exiting with status 1 if unsupported.

    A falsy ``args.language`` (no language requested) is returned unchanged.
    """
    lang: str = args.language
    if not lang or lang in LANGS:
        return lang

    sys.stderr.write(
        "I do not know about a language ending with "
        f"extension '{lang}'.\n"
        "Please send a pull request adding this language to\n"
        "https://github.com/seapagan/lice2. Thanks!\n"
    )
    raise typer.Exit(1)

73 

74 

def list_licenses() -> None:
    """Print a table of bundled licenses and the variables each one uses.

    Always finishes by raising ``typer.Exit(0)``.
    """
    table = Table(title="Available Licenses")
    for heading in ("License Name", "Variables"):
        table.add_column(heading)

    for name in LICENSES:
        variables = extract_vars(load_package_template(name))
        table.add_row(name, ", ".join(variables))

    Console().print(table)

    raise typer.Exit(0)

89 

90 

def list_languages() -> None:
    """Print the supported source-code languages, then exit with status 0."""
    console = Console(width=80)
    console.print(
        "The following source code formatting languages are supported:\n"
    )
    console.print(Text(", ".join(sorted(LANGS))))

    raise typer.Exit(0)

102 

103 

def load_file_template(path: str) -> StringIO:
    """Load a template from the specified filesystem path.

    The path is cleaned with ``clean_path`` (environment variables and ``~``
    expanded, made absolute) *before* the existence check. The path that is
    checked is therefore the same path that is opened.

    Args:
        path: Filesystem path to the template file.

    Returns:
        A StringIO containing the file content decoded as UTF-8. The stream
        position is left at the end of the content.

    Raises:
        ValueError: If the path does not exist.
    """
    resolved = Path(clean_path(path))
    if not resolved.exists():
        message = f"path does not exist: {path}"
        raise ValueError(message)

    template = StringIO()
    # Read raw bytes and decode explicitly. The result is UTF-8 regardless of
    # the platform default encoding, and line endings are kept untouched.
    template.write(resolved.read_bytes().decode("utf-8"))
    return template

114 

115 

def get_template_content(license_name: str, *, header: bool = False) -> str:
    """Read a bundled license template and return its text.

    Args:
        license_name: Name of the license whose template should be read.
        header: When True, read the source-file header variant
            (``template-<name>-header.txt``) instead of the full license
            (``template-<name>.txt``).

    Returns:
        The template text, decoded as UTF-8.

    Raises:
        FileNotFoundError: If the template doesn't exist.
    """
    suffix = "-header" if header else ""
    anchor = __package__ or __name__.split(".")[0]
    templates_dir = resources.files(anchor) / "templates"
    target = templates_dir / f"template-{license_name}{suffix}.txt"
    return target.read_text(encoding="utf-8")

137 

138 

def load_package_template(
    license_name: str, *, header: bool = False
) -> StringIO:
    """Wrap a bundled license template in a StringIO buffer.

    Args:
        license_name: Name of the license template to load.
        header: If True, load the header template instead of the full
            license.

    Returns:
        A StringIO holding the template text. The stream position is left at
        the end of the text.

    Raises:
        FileNotFoundError: If the template doesn't exist.
    """
    text = get_template_content(license_name, header=header)
    buffer = StringIO()
    buffer.write(text)
    return buffer

157 

158 

def extract_vars(template: StringIO) -> list[str]:
    """Return the sorted, de-duplicated variable names used in a template.

    A variable is written as ``{{ name }}``: double curly braces with exactly
    one space on each side of a word-character name.
    """
    pattern = re.compile(r"\{\{ (?P<key>\w+) \}\}")
    names = {
        found.group("key") for found in pattern.finditer(template.getvalue())
    }
    return sorted(names)

168 

169 

def generate_license(template: StringIO, context: dict[str, str]) -> StringIO:
    """Generate a license by filling in the template's variables.

    Every ``{{ key }}`` placeholder in the template is replaced by
    ``context[key]``. Substitution is done in a single pass over the original
    template text. Placeholders that appear inside a substituted *value*
    (e.g. a project literally named ``"{{ year }}"``) are left as-is, and
    backslashes in values are inserted literally.

    This could be done with a template engine like Jinja2, but we're keeping
    it simple.

    Args:
        template: The template text. It is closed once read.
        context: Mapping of variable name to replacement text.

    Returns:
        A new StringIO holding the rendered license. Its position is at the
        end of the text.

    Raises:
        ValueError: If the template uses a variable that is missing from
            ``context``. When several are missing, the alphabetically first
            one is reported.
    """
    # Same placeholder syntax as extract_vars: "{{", one space, a
    # word-character name, one space, "}}".
    placeholder = re.compile(r"\{\{ (?P<key>\w+) \}\}")

    out = StringIO()
    with closing(template):
        content = template.getvalue()
        # Validate every variable up front, in sorted order, so the reported
        # missing key is deterministic.
        keys = sorted({m.group("key") for m in placeholder.finditer(content)})
        for key in keys:
            if key not in context:
                message = f"{key} is missing from the template context"
                raise ValueError(message)
        # A function replacement is inserted verbatim and never rescanned.
        rendered = placeholder.sub(lambda m: context[m.group("key")], content)
    out.write(rendered)
    return out

189 

190 

def get_comments(lang: str, *, legacy: bool) -> tuple[str, str, str]:
    """Return the (prefix, line-comment, postfix) strings for a language.

    The legacy layout always emits the prefix and postfix lines, even when
    they are empty, and always follows the comment marker with a space. That
    added stray whitespace for block-comment languages. The default layout
    drops empty prefix/postfix lines and only adds the space after a
    non-empty comment marker.
    """
    prefix, comment, postfix = LANG_CMT[LANGS[lang]]

    if legacy:
        return f"{prefix}\n", f"{comment} ", f"{postfix}\n"

    def _as_line(marker: str) -> str:
        # Turn a non-empty marker into its own line; drop empty markers.
        return f"{marker}\n" if marker else ""

    line_comment = f"{comment} " if comment else comment
    return _as_line(prefix), line_comment, _as_line(postfix)

211 

212 

def format_license(
    template: StringIO, lang: str, *, legacy: bool = False
) -> StringIO:
    """Wrap every line of a rendered license in the comment syntax of lang.

    Args:
        template: The rendered license text. It is read from the start and
            closed once formatting is done.
        lang: Language key used to look up the comment markers. A falsy
            value falls back to ``"txt"``.
        legacy: Use the legacy comment layout from ``get_comments``.

    Returns:
        A new StringIO with the commented text. Its position is at the end.
    """
    prefix, comment, postfix = get_comments(lang or "txt", legacy=legacy)

    formatted = StringIO()
    with closing(template):
        template.seek(0)
        formatted.write(prefix)
        for text_line in template:
            # Blank lines get the marker without its trailing space, so no
            # trailing whitespace is introduced.
            marker = comment if text_line.strip() else comment.rstrip()
            formatted.write(marker)
            formatted.write(text_line)
        formatted.write(postfix)

    return formatted

237 

238 

def get_suffix(name: str) -> Union[str, None]:
    """Return the extension of name if it is a supported language.

    The extension is the text after the last ".". Returns None when name
    contains no "." or the extension is not a known language.
    """
    _, dot, extension = name.rpartition(".")
    if dot and extension in LANGS:
        return extension
    return None

250 

251 

def list_vars(args: SimpleNamespace, license_name: str) -> None:
    """Print the variables used by a license template, then exit with 0.

    The template is read from ``args.template_path`` when given. Otherwise
    the bundled template for license_name is used. Variables that have a
    value in the CLI context are printed together with that value.
    """
    context = get_context(args)
    label = args.template_path or license_name
    template = (
        load_file_template(args.template_path)
        if args.template_path
        else load_package_template(license_name)
    )
    names = extract_vars(template)

    if not names:
        sys.stdout.write(
            f"The {label} license template contains no variables.\n"
        )
        raise typer.Exit(0)

    sys.stdout.write(
        f"The {label} license template contains the following variables "
        "and defaults:\n"
    )
    for name in names:
        if name in context:
            sys.stdout.write(f" {name} = {context[name]}\n")
        else:
            sys.stdout.write(f" {name}\n")

    raise typer.Exit(0)

280 

281 

def generate_header(args: SimpleNamespace, lang: str) -> None:
    """Generate a source-file header for the given license and language.

    The header template comes from ``args.template_path`` if given, otherwise
    from the bundled header variant of ``args.license``. It is rendered with
    the CLI context and wrapped in lang's comment syntax. The result is then
    written to stdout, or copied to the clipboard when ``args.clipboard`` is
    set.

    Raises:
        typer.Exit: Always. The status is 0 on success and 1 if no bundled
            header template exists for the license. Clipboard failures exit
            via ``copy_to_clipboard``.
    """
    if args.template_path:
        template = load_file_template(args.template_path)
    else:
        try:
            template = load_package_template(args.license, header=True)
        except OSError:
            sys.stderr.write(
                "Sorry, no source headers are available for "
                f"{args.license}.\n"
            )
            raise typer.Exit(1) from None

    with closing(template):
        content = generate_license(template, get_context(args))
        out = format_license(content, lang, legacy=args.legacy)
        out.seek(0)
        if args.clipboard:
            # Delegate to the shared helper instead of duplicating the
            # pyperclip handling here.
            copy_to_clipboard(out)
        else:
            sys.stdout.write(out.getvalue())
    raise typer.Exit(0)

318 

319 

def validate_year(string: str) -> str:
    """Validate that the year is exactly four ASCII digits.

    ``re.fullmatch`` is used rather than ``re.match`` with ``^...$``, because
    ``$`` also matches just before a trailing newline and would let a value
    such as "2024" followed by a newline through. ``[0-9]`` is used rather
    than ``\\d`` so that non-ASCII digits (e.g. Arabic-Indic numerals) are
    rejected.

    Args:
        string: The candidate year.

    Returns:
        The year, unchanged.

    Raises:
        typer.BadParameter: If the value is not a four-digit year.
    """
    if not re.fullmatch(r"[0-9]{4}", string):
        message = "Must be a four-digit year"
        raise typer.BadParameter(message)
    return string

326 

327 

def validate_license(license_name: str) -> str:
    """Return license_name unchanged if it names an available license.

    Raises:
        typer.BadParameter: If the license is not in LICENSES. The message
            tells the user how to list the available licenses.
    """
    if license_name in LICENSES:
        return license_name

    message = (
        f"License '{license_name}' not found - please run 'lice "
        "--licenses' to get a list of available licenses."
    )
    raise typer.BadParameter(message)

337 

338 

def copy_to_clipboard(out: StringIO) -> None:
    """Copy the full contents of out to the system clipboard.

    Prints a green confirmation on success. If pyperclip cannot be imported,
    or it fails to access a clipboard, prints a red error and exits.

    Args:
        out: Buffer whose entire value (``getvalue()``) is copied.

    Raises:
        typer.Exit: With status 2 if the text could not be copied.
    """
    # The import is handled on its own. If it failed inside the same try as
    # the copy, evaluating ``pyperclip.PyperclipException`` in the except
    # clause would raise UnboundLocalError instead of this clean error.
    try:
        import pyperclip
    except ImportError as exc:
        typer.secho(
            f"Error copying to clipboard: {exc}",
            fg=typer.colors.BRIGHT_RED,
        )
        raise typer.Exit(2) from None

    try:
        pyperclip.copy(out.getvalue())
    except pyperclip.PyperclipException as exc:
        typer.secho(
            f"Error copying to clipboard: {exc}",
            fg=typer.colors.BRIGHT_RED,
        )
        raise typer.Exit(2) from None

    typer.secho(
        "License text copied to clipboard",
        fg=typer.colors.BRIGHT_GREEN,
    )

355 

356 

def get_metadata(args: SimpleNamespace) -> None:
    """Print package metadata as a single line of JSON, then exit.

    The JSON object has four keys:

    - ``languages``: the supported language keys.
    - ``licenses``: the available licenses.
    - ``organization`` and ``project``: values taken from args.

    Nothing is returned. The JSON is written to stdout, followed by a
    newline, and the function always raises ``typer.Exit(0)``.
    """
    metadata = {
        "languages": list(LANGS.keys()),
        "licenses": LICENSES,
        "organization": args.organization,
        "project": args.project,
    }

    sys.stdout.write(json.dumps(metadata) + "\n")

    raise typer.Exit(0)

374 

375 

def get_local_year() -> str:
    """Return the current calendar year, in the local timezone, as a string."""
    now_local = datetime.now().astimezone()
    return str(now_local.year)