Coverage for lice2/helpers.py: 100%
184 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-13 11:11 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-13 11:11 +0000
1"""Helper functions for LICE2."""
3import getpass
4import json
5import os
6import re
7import subprocess
8import sys
9from contextlib import closing
10from datetime import datetime
11from importlib import resources
12from io import StringIO
13from pathlib import Path
14from types import SimpleNamespace
15from typing import Union
17import typer
18from rich.console import Console
19from rich.table import Table
20from rich.text import Text
22from lice2.config import settings
23from lice2.constants import LANG_CMT, LANGS, LICENSES
def clean_path(p: str) -> str:
    """Normalise a user-supplied path.

    A leading ``~`` and any environment variables are expanded, and the
    result is resolved into an absolute path.
    """
    with_home = Path(p).expanduser()
    with_vars = os.path.expandvars(with_home)
    absolute = Path(with_vars).resolve()
    return str(absolute)
def guess_organization() -> str:
    """Guess the organization name to put in the license.

    The lookup order is:

    1. the ``organization`` value from the settings file, if non-empty;
    2. the output of ``git config --get user.name``;
    3. the login name reported by :func:`getpass.getuser`.

    Step 3 is also used when ``git`` cannot be run at all (for example it is
    not installed) or when it reports an empty name.
    """
    if settings.organization:
        return settings.organization

    try:
        stdout = subprocess.check_output(  # noqa: S603
            ["git", "config", "--get", "user.name"]  # noqa: S607
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but exited non-zero (e.g. no user.name).
        # OSError (incl. FileNotFoundError): the git executable could not be
        # started, which previously escaped and crashed the CLI.
        return getpass.getuser()

    org = stdout.strip().decode("UTF-8")
    # An explicitly empty user.name is no better than none at all.
    return org or getpass.getuser()
def get_context(args: SimpleNamespace) -> dict[str, str]:
    """Build the template substitution context from parsed CLI arguments.

    Only the ``year``, ``organization`` and ``project`` attributes of *args*
    are read. Each becomes a key of the same name in the returned dict.
    """
    wanted = ("year", "organization", "project")
    return {name: getattr(args, name) for name in wanted}
def get_lang(args: SimpleNamespace) -> str:
    """Return the requested language, exiting if it is not supported.

    An empty or unset language is passed through unchanged. Any other value
    must be a key of ``LANGS``. If it is not, an explanatory message is
    written to stderr and the CLI exits with status 1.
    """
    lang: str = args.language
    if not lang or lang in LANGS:
        return lang

    sys.stderr.write(
        "I do not know about a language ending with "
        f"extension '{lang}'.\n"
        "Please send a pull request adding this language to\n"
        "https://github.com/seapagan/lice2. Thanks!\n"
    )
    raise typer.Exit(1)
def list_licenses() -> None:
    """Print a table of bundled licenses and the template variables each uses.

    Always finishes by raising ``typer.Exit(0)``.
    """
    rows = [
        (name, ", ".join(extract_vars(load_package_template(name))))
        for name in LICENSES
    ]

    table = Table(title="Available Licenses")
    for heading in ("License Name", "Variables"):
        table.add_column(heading)
    for name, variables in rows:
        table.add_row(name, variables)

    Console().print(table)
    raise typer.Exit(0)
def list_languages() -> None:
    """Print the supported source-code languages as a comma-separated list.

    Output is rendered by a console fixed at 80 columns wide. Always finishes
    by raising ``typer.Exit(0)``.
    """
    console = Console(width=80)
    console.print(
        "The following source code formatting languages are supported:\n"
    )
    console.print(Text(", ".join(sorted(LANGS))))
    raise typer.Exit(0)
def load_file_template(path: str) -> StringIO:
    """Load a license template from a file on disk.

    ``~`` and environment variables in *path* are expanded (via
    :func:`clean_path`) *before* the existence check. The path that is checked
    is therefore the same path that is opened. Previously the raw, unexpanded
    path was checked, so ``~/my-template.txt`` was always rejected.

    Args:
        path: Filesystem path to the template file.

    Returns:
        A StringIO holding the file contents decoded as UTF-8, positioned at
        its end.

    Raises:
        ValueError: If the expanded path does not exist.
    """
    resolved = Path(clean_path(path))
    if not resolved.exists():
        message = f"path does not exist: {path}"
        raise ValueError(message)

    template = StringIO()
    # Read raw bytes and decode explicitly. This always uses UTF-8 regardless
    # of the platform default, and line endings are kept exactly as stored
    # (no universal-newline translation).
    template.write(resolved.read_bytes().decode("utf-8"))
    return template
def get_template_content(license_name: str, *, header: bool = False) -> str:
    """Read a license template bundled with the package and return its text.

    Templates live in the package's ``templates`` directory and are named
    ``template-<name>.txt``, or ``template-<name>-header.txt`` for headers.

    Args:
        license_name: Name of the license template to load
        header: If True, load the header template instead of the full license

    Returns:
        The template content as a string

    Raises:
        FileNotFoundError: If the template doesn't exist
    """
    suffix = "-header" if header else ""
    filename = f"template-{license_name}{suffix}.txt"
    package_name = __package__ or __name__.split(".")[0]
    templates_dir = resources.files(package_name).joinpath("templates")
    return templates_dir.joinpath(filename).read_text(encoding="utf-8")
def load_package_template(
    license_name: str, *, header: bool = False
) -> StringIO:
    """Wrap a bundled license template in a StringIO.

    Args:
        license_name: Name of the license template to load
        header: If True, load the header template instead of the full license

    Returns:
        StringIO object containing the template content. The stream position
        is left at the end of the written text.

    Raises:
        FileNotFoundError: If the template doesn't exist
    """
    text = get_template_content(license_name, header=header)
    buffer = StringIO()
    buffer.write(text)
    return buffer
def extract_vars(template: StringIO) -> list[str]:
    """Return the sorted, de-duplicated variable names used in *template*.

    A variable is written as ``{{ name }}``, with exactly one space inside
    each pair of braces and *name* made of word characters.
    """
    pattern = r"\{\{ (?P<key>\w+) \}\}"
    found = {
        match.group("key")
        for match in re.finditer(pattern, template.getvalue())
    }
    return sorted(found)
def generate_license(template: StringIO, context: dict[str, str]) -> StringIO:
    """Generate a license by filling in a template's variables.

    Every ``{{ name }}`` placeholder in *template* is replaced by
    ``context[name]``. Substitution is a single pass over the original
    template text. Placeholder-like text that appears *inside* a context
    value (e.g. a project called ``"{{ year }}"``) is therefore copied
    verbatim instead of being substituted a second time.

    This could be done with a template engine like Jinja2, but we're keeping
    it simple.

    Args:
        template: The template text. It is closed before this function
            returns, including when an error is raised.
        context: Mapping of variable name to replacement text.

    Returns:
        A new StringIO containing the rendered license, positioned at its end.

    Raises:
        ValueError: If the template uses a variable that is not in *context*.
    """
    out = StringIO()
    with closing(template):
        content = template.getvalue()
        # Validate every variable up front (in sorted order, as before) so a
        # missing key is reported before any output is produced.
        for key in extract_vars(template):
            if key not in context:
                message = f"{key} is missing from the template context"
                raise ValueError(message)
        # A function replacement inserts the value literally: no backslash or
        # group-reference processing, and no re-scanning of inserted text.
        rendered = re.sub(
            r"\{\{ (?P<key>\w+) \}\}",
            lambda match: context[match.group("key")],
            content,
        )
    out.write(rendered)
    return out
def get_comments(lang: str, *, legacy: bool) -> tuple[str, str, str]:
    """Return the (prefix, per-line comment, postfix) strings for *lang*.

    In legacy mode every part gets its historical padding, even when it is
    empty: prefix and postfix get a trailing newline and the per-line comment
    gets a trailing space. In non-legacy mode only non-empty parts are padded.
    This avoids the extra whitespace the legacy format added around block
    comments.
    """
    opener, marker, closer = LANG_CMT[LANGS[lang]]

    if not legacy:
        opener = f"{opener}\n" if opener else ""
        marker = f"{marker} " if marker else marker
        closer = f"{closer}\n" if closer else ""
        return opener, marker, closer

    return f"{opener}\n", f"{marker} ", f"{closer}\n"
def format_license(
    template: StringIO, lang: str, *, legacy: bool = False
) -> StringIO:
    """Wrap every line of *template* in the comment syntax for *lang*.

    An empty *lang* is treated as ``"txt"``. The template is read from its
    beginning and is closed afterwards. Whitespace-only lines get the comment
    marker with its trailing whitespace stripped.

    Returns:
        A new StringIO holding the commented text.
    """
    prefix, comment, postfix = get_comments(lang or "txt", legacy=legacy)

    formatted = StringIO()
    with closing(template):
        template.seek(0)
        formatted.write(prefix)
        for line in template:
            marker = comment if line.strip() else comment.rstrip()
            formatted.write(marker)
            formatted.write(line)
        formatted.write(postfix)

    return formatted
def get_suffix(name: str) -> Union[str, None]:
    """Return the extension of *name* if it is a supported language.

    The extension is the text after the last ``.``. ``None`` is returned when
    *name* contains no ``.`` or when its extension is not a key of ``LANGS``.
    """
    if "." not in name:
        return None
    _, _, extension = name.rpartition(".")
    return extension if extension in LANGS else None
def list_vars(args: SimpleNamespace, license_name: str) -> None:
    """Print the variables used by a license template and their defaults.

    The template comes from ``args.template_path`` when set, otherwise from
    the bundled template for *license_name*. Variables that have a value in
    the context built from *args* are printed as ``name = value``. Always
    finishes by raising ``typer.Exit(0)``.
    """
    context = get_context(args)
    template = (
        load_file_template(args.template_path)
        if args.template_path
        else load_package_template(license_name)
    )
    names = extract_vars(template)
    label = args.template_path or license_name

    if not names:
        sys.stdout.write(
            f"The {label} license template contains no variables.\n"
        )
        raise typer.Exit(0)

    sys.stdout.write(
        f"The {label} license template contains the following variables "
        "and defaults:\n"
    )
    for name in names:
        if name in context:
            sys.stdout.write(f" {name} = {context[name]}\n")
        else:
            sys.stdout.write(f" {name}\n")

    raise typer.Exit(0)
def generate_header(args: SimpleNamespace, lang: str) -> None:
    """Generate a source-file header for the given license and language.

    The header template comes from ``args.template_path`` when set, otherwise
    from the bundled header template for ``args.license``. It is rendered
    with the context built from *args* and wrapped in *lang*'s comment
    syntax. The result is written to stdout, or copied to the clipboard when
    ``args.clipboard`` is set.

    Raises:
        typer.Exit: Status 1 if no bundled header template can be read for
            ``args.license``. Status 2 if the clipboard copy fails (raised by
            ``copy_to_clipboard``). Otherwise status 0 once output is done.
    """
    if args.template_path:
        template = load_file_template(args.template_path)
    else:
        try:
            template = load_package_template(args.license, header=True)
        except OSError:
            sys.stderr.write(
                "Sorry, no source headers are available for "
                f"{args.license}.\n"
            )
            raise typer.Exit(1) from None

    with closing(template):
        content = generate_license(template, get_context(args))
        out = format_license(content, lang, legacy=args.legacy)
        out.seek(0)
        if args.clipboard:
            # Reuse the shared helper instead of duplicating its
            # import/copy/report/exit logic here.
            copy_to_clipboard(out)
        else:
            sys.stdout.write(out.getvalue())
    raise typer.Exit(0)
def validate_year(string: str) -> str:
    """Check that *string* is exactly four ASCII digits and return it.

    ``re.fullmatch`` with an explicit ``[0-9]`` class is used deliberately.
    The previous ``^...$`` anchored ``re.match`` also accepted a trailing
    newline, because ``$`` matches just before one. The ``\\d`` class also
    accepted non-ASCII Unicode digits.

    Raises:
        typer.BadParameter: If the value is not a four-digit year.
    """
    if not re.fullmatch(r"[0-9]{4}", string):
        message = "Must be a four-digit year"
        raise typer.BadParameter(message)
    return string
def validate_license(license_name: str) -> str:
    """Return *license_name* unchanged if it names a bundled license.

    Raises:
        typer.BadParameter: If the name is not in ``LICENSES``. The message
            tells the user to run ``lice --licenses``.
    """
    if license_name in LICENSES:
        return license_name

    message = (
        f"License '{license_name}' not found - please run 'lice "
        "--licenses' to get a list of available licenses."
    )
    raise typer.BadParameter(message)
def copy_to_clipboard(out: StringIO) -> None:
    """Copy the full contents of *out* to the system clipboard.

    On success a confirmation is printed in green. If ``pyperclip`` cannot be
    imported, or its ``copy`` raises ``PyperclipException``, an error is
    printed in red and the CLI exits with status 2.
    """
    try:
        import pyperclip
    except ImportError as exc:
        # Handled on its own: when the import sat inside the same ``try`` as
        # ``except pyperclip.PyperclipException``, evaluating that clause
        # with ``pyperclip`` unbound raised a confusing UnboundLocalError.
        typer.secho(
            f"Error copying to clipboard: {exc}",
            fg=typer.colors.BRIGHT_RED,
        )
        raise typer.Exit(2) from None

    try:
        pyperclip.copy(out.getvalue())
    except pyperclip.PyperclipException as exc:
        typer.secho(
            f"Error copying to clipboard: {exc}",
            fg=typer.colors.BRIGHT_RED,
        )
        raise typer.Exit(2) from None

    typer.secho(
        "License text copied to clipboard",
        fg=typer.colors.BRIGHT_GREEN,
    )
def get_metadata(args: SimpleNamespace) -> None:
    """Print package metadata as one line of JSON, then exit.

    The JSON object has the keys ``languages`` (the keys of ``LANGS``),
    ``licenses`` (``LICENSES``), ``organization`` and ``project`` (the last
    two taken from *args*). Nothing is returned; the function always
    finishes by raising ``typer.Exit(0)``.
    """
    payload = {
        "languages": list(LANGS),
        "licenses": LICENSES,
        "organization": args.organization,
        "project": args.project,
    }
    sys.stdout.write(f"{json.dumps(payload)}\n")
    raise typer.Exit(0)
def get_local_year() -> str:
    """Return the current calendar year, in the local timezone, as a string."""
    now_local = datetime.now().astimezone()
    return str(now_local.year)