Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/_lib/_uarray/_backend.py : 43%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
import atexit
import copyreg  # type: ignore
import functools
import inspect
import pickle
import threading
import typing

from . import _uarray  # type: ignore
# Signature of an argument extractor: it accepts the multimethod's own
# arguments and returns the tuple of ``Dispatchable`` objects to dispatch on.
ArgumentExtractorType = typing.Callable[
    ...,
    typing.Tuple["Dispatchable", ...],
]

# Signature of an argument replacer: ``(args, kwargs, dispatchables)`` in,
# an ``(args, kwargs)`` pair with the dispatchables substituted back out.
ArgumentReplacerType = typing.Callable[
    [typing.Tuple, typing.Dict, typing.Tuple],
    typing.Tuple[typing.Tuple, typing.Dict],
]
14from ._uarray import ( # type: ignore
15 BackendNotImplementedError,
16 _Function,
17 _SkipBackendContext,
18 _SetBackendContext,
19)
# Public API of this module (order preserved).
__all__ = [
    # Backend management.
    "set_backend",
    "set_global_backend",
    "skip_backend",
    "register_backend",
    "clear_backends",
    # Multimethod construction.
    "create_multimethod",
    "generate_multimethod",
    "_Function",
    "BackendNotImplementedError",
    # Dispatch-marking helpers.
    "Dispatchable",
    "wrap_single_convertor",
    "all_of_type",
    "mark_as",
]
def unpickle_function(mod_name, qname):
    """
    Resolve a function that was pickled by reference.

    This is the reconstructor paired with :func:`pickle_function`.

    Parameters
    ----------
    mod_name : str
        Name of the module in which the function is defined.
    qname : str
        The function's ``__qualname__``. Dotted qualified names (for example
        ``"SomeClass.method"``) are resolved one attribute at a time.

    Returns
    -------
    object
        The object found at ``mod_name`` / ``qname``.

    Raises
    ------
    pickle.UnpicklingError
        If the module cannot be imported or the name cannot be resolved. This
        includes a missing (``None``) or empty module name or qualified name.
        The underlying exception is chained as ``__cause__``.
    """
    import importlib
    from pickle import UnpicklingError

    try:
        obj = importlib.import_module(mod_name)
        # getattr() does not follow dots, so walk nested qualnames segment by
        # segment. A None qname fails here with AttributeError and is reported
        # as UnpicklingError below.
        for attr in qname.split("."):
            obj = getattr(obj, attr)
    except (ImportError, AttributeError, ValueError) as e:
        # ValueError: importlib rejects an empty module name.
        raise UnpicklingError(
            "Cannot find {!r} in module {!r}".format(qname, mod_name)
        ) from e
    return obj
def pickle_function(func):
    """
    ``copyreg`` reducer that pickles a function by reference.

    The function is recorded as its ``__module__`` plus its ``__qualname__``
    (``None`` for either attribute if it is missing). Before that pair is
    accepted, it is resolved with :func:`unpickle_function`. The result must be
    ``func`` itself, so that unpickling yields the very same object.

    Returns
    -------
    tuple
        ``(unpickle_function, (module_name, qualified_name))``.

    Raises
    ------
    pickle.PicklingError
        If the recorded name does not resolve back to ``func``.
    """
    location = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )

    try:
        resolved = unpickle_function(*location)
    except pickle.UnpicklingError:
        resolved = None

    if resolved is func:
        return unpickle_function, location

    raise pickle.PicklingError(
        "Can't pickle {}: it's not the same object as {}".format(func, resolved)
    )
# Teach pickle to serialise multimethod objects (_Function instances) by
# reference, using the reducer defined above.
copyreg.pickle(_Function, pickle_function=pickle_function)

# At interpreter exit, call the extension's clear_all_globals hook.
atexit.register(_uarray.clear_all_globals)
def create_multimethod(*args, **kwargs):
    """
    Creates a decorator for generating multimethods.

    The returned decorator receives the argument extractor and calls
    :obj:`generate_multimethod` with it as the first argument. Every argument
    given to ``create_multimethod`` itself is forwarded after it unchanged.

    See Also
    --------
    generate_multimethod
        Generates a multimethod.
    """

    def decorate(a):
        # ``a`` is the argument extractor being decorated.
        return generate_multimethod(a, *args, **kwargs)

    return decorate
def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None,
):
    """
    Generates a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable which extracts the dispatchable arguments. Extracted arguments
        should be marked by the :obj:`Dispatchable` class. It has the same signature
        as the desired multimethod.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables), which should also
        return an (args, kwargs) pair with the dispatchables replaced inside the args/kwargs.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod, where ``None`` (the default) specifies
        there is no default implementation.

    Returns
    -------
    _Function
        The multimethod, with metadata (name, docstring, ...) copied from
        ``argument_extractor`` via :func:`functools.update_wrapper`.

    Examples
    --------
    In this example, ``a`` is to be dispatched over, so we return it, while marking it as an ``int``.
    The trailing comma is needed because the args have to be returned as an iterable.

    >>> def override_me(a, b):
    ...   return Dispatchable(a, int),

    Next, we define the argument replacer that replaces the dispatchables inside args/kwargs with the
    supplied ones.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    Next, we define the multimethod.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Notice that there's no default implementation, unless you supply one.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.backend.BackendNotImplementedError: ...
    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray
        See the module documentation for how to override the method by creating backends.
    """
    # Default values declared in the extractor's signature are handed to the
    # dispatcher; the third return value (names of defaulted params) is unused.
    kw_defaults, arg_defaults, _ = get_defaults(argument_extractor)
    ua_func = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        arg_defaults,
        kw_defaults,
        default,
    )

    return functools.update_wrapper(ua_func, argument_extractor)
def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    Returns
    -------
    context manager
        A cached context object. Repeated calls with the same ``backend``,
        ``coerce`` and ``only`` from the same thread return the same object.

    See Also
    --------
    skip_backend: A context manager that allows skipping of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    # The cache on the backend is keyed per thread.
    # NOTE(review): the C-level context appears to bind to the creating
    # thread's backend state when constructed, so it must not be handed to
    # other threads. This mirrors the later upstream SciPy fix; verify.
    # Thread idents may be reused after a thread exits.
    # NOTE(review): ``__ua_cache__`` is read with ordinary attribute lookup, so
    # a backend class inheriting it from a parent backend class would share
    # (and be served) the parent's cached contexts; confirm backends are never
    # subclassed.
    key = (threading.get_ident(), "set", coerce, only)
    try:
        return backend.__ua_cache__[key]
    except AttributeError:
        # First use of this backend: attach an empty cache to it.
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SetBackendContext(backend, coerce, only)
    backend.__ua_cache__[key] = ctx
    return ctx
def skip_backend(backend):
    """
    A context manager that allows one to skip a given backend from processing
    entirely. This allows one to use another backend's code in a library that
    is also a consumer of the same backend.

    Parameters
    ----------
    backend
        The backend to skip.

    Returns
    -------
    context manager
        A cached context object. Repeated calls for the same ``backend`` from
        the same thread return the same object.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    # The cache on the backend is keyed per thread.
    # NOTE(review): the C-level context appears to bind to the creating
    # thread's backend state when constructed, so it must not be shared across
    # threads. This mirrors the later upstream SciPy fix; verify. Thread
    # idents may be reused after a thread exits.
    key = (threading.get_ident(), "skip")
    try:
        return backend.__ua_cache__[key]
    except AttributeError:
        # First use of this backend: attach an empty cache to it.
        backend.__ua_cache__ = {}
    except KeyError:
        pass

    ctx = _SkipBackendContext(backend)
    backend.__ua_cache__[key] = ctx
    return ctx
def get_defaults(f):
    """
    Collect the default argument values declared in ``f``'s signature.

    Returns
    -------
    kw_defaults : dict
        Maps every parameter that has a default to that default, in signature
        order.
    arg_defaults : tuple
        Defaults of the positional-only / positional-or-keyword parameters
        only, in signature order.
    opts : set
        Names of all parameters that have a default.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    with_default = [
        param
        for param in inspect.signature(f).parameters.values()
        if param.default is not inspect.Parameter.empty
    ]

    kw_defaults = {param.name: param.default for param in with_default}
    arg_defaults = tuple(
        param.default for param in with_default if param.kind in positional_kinds
    )
    opts = {param.name for param in with_default}
    return kw_defaults, arg_defaults, opts
def set_global_backend(backend, coerce=False, only=False):
    """
    Replace the default backend for permanent use.

    The global backend is tried automatically in the list of backends, unless
    a backend with the ``only`` flag set takes precedence. Outside the
    :obj:`set_backend` context manager, it is the first backend tried.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves, or by a reference
        implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to register.
    coerce
        Forwarded to the native implementation. Presumably it has the same
        meaning as in :obj:`set_backend` (coerce to this backend's types).
    only
        Forwarded to the native implementation. Presumably it has the same
        meaning as in :obj:`set_backend` (no backend is tried after this one).

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    skip_backend: A context manager that allows skipping of backends.
    """
    native_setter = _uarray.set_global_backend
    native_setter(backend, coerce, only)
def register_backend(backend):
    """
    Register a backend for permanent use.

    A registered backend is tried automatically in the list of backends,
    unless a backend with the ``only`` flag set takes precedence.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.
    """
    native_register = _uarray.register_backend
    native_register(backend)
def clear_backends(domain, registered=True, globals=False):
    """
    Clear registered and/or global backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by the users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain whose backends are de-registered; ``None`` de-registers
        them for all domains.
    registered : bool
        Whether to clear backends added with :obj:`register_backend`.
    globals : bool
        Whether to clear backends set with :obj:`set_global_backend`.
        (The name shadows the builtin but is part of the public signature.)

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    native_clear = _uarray.clear_backends
    native_clear(domain, registered, globals)
class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Attributes
    ----------
    value
        The value of the Dispatchable.

    type
        The type of the Dispatchable.

    coercible
        Whether this value may be coerced. :func:`wrap_single_convertor`
        passes ``coerce and coercible`` to the single-value convertor, so a
        ``False`` here disables coercion for this value even when coercion
        is requested.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.

    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value, self.type, self.coercible = value, dispatch_type, coercible

    def __getitem__(self, index):
        # Indexes like the pair ``(type, value)``; slices and negatives work too.
        as_pair = (self.type, self.value)
        return as_pair[index]

    def __str__(self):
        template = "<{0}: type={1!r}, value={2!r}>"
        return template.format(type(self).__name__, self.type, self.value)

    __repr__ = __str__
def mark_as(dispatch_type):
    """
    Build a helper that marks a value as ``dispatch_type``.

    The helper is a :func:`functools.partial` of :obj:`Dispatchable` with
    ``dispatch_type`` pre-bound, so it is called with the value (and,
    optionally, ``coercible``).

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    marker = functools.partial(Dispatchable, dispatch_type=dispatch_type)
    return marker
def all_of_type(arg_type):
    """
    Marks all unmarked arguments as a given type.

    Decorates an argument extractor: every item it returns that is not already
    a :obj:`Dispatchable` is wrapped as ``Dispatchable(item, arg_type)``.
    Items that are already marked are passed through unchanged. The result is
    always a tuple.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def ensure_marked(item):
        if isinstance(item, Dispatchable):
            return item
        return Dispatchable(item, arg_type)

    def decorate(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            return tuple(map(ensure_marked, func(*args, **kwargs)))

        return inner

    return decorate
def wrap_single_convertor(convert_single):
    """
    Lift a per-element ``__ua_convert__`` to a sequence of dispatchables.

    ``convert_single`` is called as ``convert_single(value, type, coerce)``
    for each dispatchable, in order. The per-element ``coerce`` flag is
    ``coerce and d.coercible``. As soon as one call returns
    ``NotImplemented``, the whole conversion is treated as undefined:
    ``NotImplemented`` is returned and the remaining elements are not
    visited. Otherwise, the list of converted values is returned.
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        results = []
        append = results.append
        for item in dispatchables:
            outcome = convert_single(item.value, item.type, coerce and item.coercible)
            if outcome is NotImplemented:
                break
            append(outcome)
        else:
            return results
        return NotImplemented

    return __ua_convert__