Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/integrate/_quad_vec.py : 13%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import sys
2import copy
3import heapq
4import collections
5import functools
7import numpy as np
9from scipy._lib._util import MapWrapper
class LRUDict(collections.OrderedDict):
    """Ordered dictionary that holds at most `max_size` entries.

    Inserting a *new* key when the dictionary is already full evicts the
    entry at the front (the oldest insertion). Re-assigning an *existing*
    key moves it to the back, so it becomes the last candidate for
    eviction. Plain lookups do not reorder entries.
    """
    def __init__(self, max_size):
        # Run the base-class initialiser explicitly: the pure-Python
        # OrderedDict sets up its internal linked list in __init__, so
        # skipping it breaks the mapping on interpreters that lack the
        # C implementation.
        super(LRUDict, self).__init__()
        self.__max_size = max_size

    def __setitem__(self, key, value):
        existing_key = (key in self)
        super(LRUDict, self).__setitem__(key, value)
        if existing_key:
            # Refresh the position of a re-assigned key.
            self.move_to_end(key)
        elif len(self) > self.__max_size:
            # Drop the oldest entry to stay within the size bound.
            self.popitem(last=False)

    def update(self, other):
        # Not needed below; bulk updates are deliberately unsupported.
        raise NotImplementedError()
class SemiInfiniteFunc(object):
    """
    Argument transform from (start, +-oo) to (0, 1)

    The original variable is recovered as ``x = start + sgn*(1 - t)/t``,
    where ``sgn`` is the sign of the infinite end point. Calling the
    instance returns the transformed integrand ``sgn * f(x) / t**2``.
    """
    def __init__(self, func, start, infty):
        self._func = func
        self._start = start
        if infty < 0:
            self._sgn = -1
        else:
            self._sgn = 1

        # Below this t the 1/t**2 Jacobian factor risks overflowing, so the
        # integrand is reported as zero there instead.
        self._tmin = sys.float_info.min**0.5

    def get_t(self, x):
        """Map a point `x` of the original range to the t coordinate."""
        shifted = self._sgn * (x - self._start) + 1
        # shifted == 0 only happens for points outside the integration range.
        return np.inf if shifted == 0 else 1 / shifted

    def __call__(self, t):
        if t < self._tmin:
            return 0.0
        sgn = self._sgn
        value = self._func(self._start + sgn * (1 - t) / t)
        return sgn * (value / t) / t
class DoubleInfiniteFunc(object):
    """
    Argument transform from (-oo, oo) to (-1, 1)

    The original variable is recovered as ``x = (1 - |t|)/t``. Calling the
    instance returns the transformed integrand ``f(x) / t**2``.
    """
    def __init__(self, func):
        self._func = func

        # Below this |t| the 1/t**2 Jacobian factor risks overflowing, so
        # the integrand is reported as zero there instead.
        self._tmin = sys.float_info.min**0.5

    def get_t(self, x):
        """Map a point `x` of the real line to the t coordinate."""
        if x < 0:
            return -1 / (abs(x) + 1)
        return 1 / (abs(x) + 1)

    def __call__(self, t):
        if abs(t) < self._tmin:
            return 0.0
        value = self._func((1 - abs(t)) / t)
        return (value / t) / t
80def _max_norm(x):
81 return np.amax(abs(x))
84def _get_sizeof(obj):
85 try:
86 return sys.getsizeof(obj)
87 except TypeError:
88 # occurs on pypy
89 if hasattr(obj, '__sizeof__'):
90 return int(obj.__sizeof__())
91 return 64
94class _Bunch(object):
95 def __init__(self, **kwargs):
96 self.__keys = kwargs.keys()
97 self.__dict__.update(**kwargs)
99 def __repr__(self):
100 return "_Bunch({})".format(", ".join("{}={}".format(k, repr(self.__dict__[k]))
101 for k in self.__keys))
def quad_vec(f, a, b, epsabs=1e-200, epsrel=1e-8, norm='2', cache_size=100e6, limit=10000,
             workers=1, points=None, quadrature=None, full_output=False):
    r"""Adaptive integration of a vector-valued function.

    Parameters
    ----------
    f : callable
        Vector-valued function f(x) to integrate.
    a : float
        Initial point.
    b : float
        Final point.
    epsabs : float, optional
        Absolute tolerance.
    epsrel : float, optional
        Relative tolerance.
    norm : {'max', '2'} or callable, optional
        Vector norm to use for error estimation. A callable is used as-is
        and must map a value returned by `f` to a scalar.
    cache_size : int, optional
        Number of bytes to use for memoization.
    limit : int, optional
        Upper bound on the number of subintervals. It is checked before
        each round of subdivision, so the final number of intervals may
        slightly exceed it.
    workers : int or map-like callable, optional
        If `workers` is an integer, part of the computation is done in
        parallel subdivided to this many tasks (using
        :class:`python:multiprocessing.pool.Pool`).
        Supply `-1` to use all cores available to the Process.
        Alternatively, supply a map-like callable, such as
        :meth:`python:multiprocessing.pool.Pool.map` for evaluating the
        population in parallel.
        This evaluation is carried out as ``workers(func, iterable)``.
    points : list, optional
        List of additional breakpoints.
    quadrature : {'gk21', 'gk15', 'trapz'}, optional
        Quadrature rule to use on subintervals.
        Options: 'gk21' (Gauss-Kronrod 21-point rule),
        'gk15' (Gauss-Kronrod 15-point rule),
        'trapz' (composite trapezoid rule).
        Default: 'gk21' for finite intervals and 'gk15' for (semi-)infinite
        intervals.
    full_output : bool, optional
        Return an additional ``info`` dictionary.

    Returns
    -------
    res : {float, array-like}
        Estimate for the result
    err : float
        Error estimate for the result in the given norm (the sum of the
        discretization and rounding error estimates).
    info : dict
        Returned only when ``full_output=True``.
        Info dictionary. Is an object with the attributes:

            success : bool
                Whether integration reached target precision.
            status : int
                Indicator for convergence, success (0),
                failure (1), failure due to rounding error (2),
                and failure due to non-finite values (3).
            message : str
                Human-readable description of ``status``.
            neval : int
                Number of function evaluations.
            intervals : ndarray, shape (num_intervals, 2)
                Start and end points of subdivision intervals.
            integrals : ndarray, shape (num_intervals, ...)
                Integral for each interval.
                Note that only as many values as fit in ``cache_size``
                bytes are recorded, and the array may contain *nan* for
                missing items.
            errors : ndarray, shape (num_intervals,)
                Estimated integration error for each interval.

    Notes
    -----
    The algorithm mainly follows the implementation of QUADPACK's
    DQAG* algorithms, implementing global error control and adaptive
    subdivision.

    The algorithm here has some differences to the QUADPACK approach:

    Instead of subdividing one interval at a time, the algorithm
    subdivides N intervals with largest errors at once. This enables
    (partial) parallelization of the integration.

    The logic of subdividing "next largest" intervals first is then
    not implemented, and we rely on the above extension to avoid
    concentrating on "small" intervals only.

    The Wynn epsilon table extrapolation is not used (QUADPACK uses it
    for infinite intervals). This is because the algorithm here is
    supposed to work on vector-valued functions, in a user-specified
    norm, and the extension of the epsilon algorithm to this case does
    not appear to be widely agreed upon. For max-norm, using elementwise
    Wynn epsilon could be possible, but we do not do this here with
    the hope that the epsilon extrapolation is mainly useful in
    special cases.

    References
    ----------
    [1] R. Piessens, E. de Doncker, QUADPACK (1983).

    Examples
    --------
    We can compute integrations of a vector-valued function:

    >>> from scipy.integrate import quad_vec
    >>> import matplotlib.pyplot as plt
    >>> alpha = np.linspace(0.0, 2.0, num=30)
    >>> f = lambda x: x**alpha
    >>> x0, x1 = 0, 2
    >>> y, err = quad_vec(f, x0, x1)
    >>> plt.plot(alpha, y)
    >>> plt.xlabel(r"$\alpha$")
    >>> plt.ylabel(r"$\int_{0}^{2} x^\alpha dx$")
    >>> plt.show()

    """
    a = float(a)
    b = float(b)

    # Use simple transformations to deal with integrals over infinite
    # intervals: map them onto (0, 1) or (-1, 1) and recurse.
    kwargs = dict(epsabs=epsabs,
                  epsrel=epsrel,
                  norm=norm,
                  cache_size=cache_size,
                  limit=limit,
                  workers=workers,
                  points=points,
                  quadrature='gk15' if quadrature is None else quadrature,
                  full_output=full_output)
    if np.isfinite(a) and np.isinf(b):
        f2 = SemiInfiniteFunc(f, start=a, infty=b)
        if points is not None:
            kwargs['points'] = tuple(f2.get_t(xp) for xp in points)
        return quad_vec(f2, 0, 1, **kwargs)
    elif np.isfinite(b) and np.isinf(a):
        # Integrate from b towards a, then flip the sign of the result.
        f2 = SemiInfiniteFunc(f, start=b, infty=a)
        if points is not None:
            kwargs['points'] = tuple(f2.get_t(xp) for xp in points)
        res = quad_vec(f2, 0, 1, **kwargs)
        return (-res[0],) + res[1:]
    elif np.isinf(a) and np.isinf(b):
        sgn = -1 if b < a else 1

        # NB. explicitly split integral at t=0, which separates
        # the positive and negative sides
        f2 = DoubleInfiniteFunc(f)
        if points is not None:
            kwargs['points'] = (0,) + tuple(f2.get_t(xp) for xp in points)
        else:
            kwargs['points'] = (0,)

        if a != b:
            res = quad_vec(f2, -1, 1, **kwargs)
        else:
            # Both bounds are the same infinity: empty integration range.
            res = quad_vec(f2, 1, 1, **kwargs)

        return (res[0]*sgn,) + res[1:]
    elif not (np.isfinite(a) and np.isfinite(b)):
        # Reached only for NaN bounds.
        raise ValueError("invalid integration bounds a={}, b={}".format(a, b))

    norm_funcs = {
        None: _max_norm,
        'max': _max_norm,
        '2': np.linalg.norm
    }
    if callable(norm):
        norm_func = norm
    else:
        norm_func = norm_funcs[norm]

    mapwrapper = MapWrapper(workers)

    # Maximum number of intervals subdivided in one round.
    parallel_count = 128
    # Convergence is only declared once at least this many intervals exist.
    min_intervals = 2

    try:
        _quadrature = {None: _quadrature_gk21,
                       'gk21': _quadrature_gk21,
                       'gk15': _quadrature_gk15,
                       'trapz': _quadrature_trapz}[quadrature]
    except KeyError:
        raise ValueError("unknown quadrature {!r}".format(quadrature))

    # Initial interval set: split [a, b] at the distinct breakpoints that
    # lie strictly inside it.
    if points is None:
        initial_intervals = [(a, b)]
    else:
        prev = a
        initial_intervals = []
        for p in sorted(points):
            p = float(p)
            if not (a < p < b) or p == prev:
                continue
            initial_intervals.append((prev, p))
            prev = p
        initial_intervals.append((prev, b))

    global_integral = None
    global_error = None
    rounding_error = None
    interval_cache = None
    intervals = []
    neval = 0

    for x1, x2 in initial_intervals:
        ig, err, rnd = _quadrature(x1, x2, f, norm_func)
        neval += _quadrature.num_eval

        if global_integral is None:
            if isinstance(ig, (float, complex)):
                # Specialize for scalars
                if norm_func in (_max_norm, np.linalg.norm):
                    norm_func = abs

            global_integral = ig
            global_error = float(err)
            rounding_error = float(rnd)

            cache_count = cache_size // _get_sizeof(ig)
            interval_cache = LRUDict(cache_count)
        else:
            global_integral += ig
            global_error += err
            rounding_error += rnd

        # Cache a copy: global_integral may alias this ig object and is
        # later updated in place with +=.
        interval_cache[(x1, x2)] = copy.copy(ig)
        # Min-heap keyed on -err, i.e. largest error first.
        intervals.append((-err, x1, x2))

    heapq.heapify(intervals)

    CONVERGED = 0
    NOT_CONVERGED = 1
    ROUNDING_ERROR = 2
    NOT_A_NUMBER = 3

    status_msg = {
        CONVERGED: "Target precision reached.",
        NOT_CONVERGED: "Target precision not reached.",
        ROUNDING_ERROR: "Target precision could not be reached due to rounding error.",
        NOT_A_NUMBER: "Non-finite values encountered."
    }

    # Process intervals
    with mapwrapper:
        ier = NOT_CONVERGED

        while intervals and len(intervals) < limit:
            # Select intervals with largest errors for subdivision
            tol = max(epsabs, epsrel*norm_func(global_integral))

            to_process = []
            err_sum = 0

            for j in range(parallel_count):
                if not intervals:
                    break

                if j > 0 and err_sum > global_error - tol/8:
                    # avoid unnecessary parallel splitting
                    break

                # Separate names keep the integration bounds a, b unshadowed.
                neg_old_err, sub_a, sub_b = heapq.heappop(intervals)
                old_int = interval_cache.pop((sub_a, sub_b), None)
                to_process.append(((-neg_old_err, sub_a, sub_b, old_int),
                                   f, norm_func, _quadrature))
                err_sum += -neg_old_err

            # Subdivide intervals and fold the deltas into the global sums.
            for dint, derr, dround_err, subint, dneval in mapwrapper(_subdivide_interval, to_process):
                neval += dneval
                global_integral += dint
                global_error += derr
                rounding_error += dround_err
                for x1, x2, ig, err in subint:
                    interval_cache[(x1, x2)] = ig
                    heapq.heappush(intervals, (-err, x1, x2))

            # Termination check
            if len(intervals) >= min_intervals:
                tol = max(epsabs, epsrel*norm_func(global_integral))
                if global_error < tol/8:
                    ier = CONVERGED
                    break
                if global_error < rounding_error:
                    ier = ROUNDING_ERROR
                    break

            if not (np.isfinite(global_error) and np.isfinite(rounding_error)):
                ier = NOT_A_NUMBER
                break

    res = global_integral
    err = global_error + rounding_error

    if full_output:
        res_arr = np.asarray(res)
        # Placeholder for intervals whose integral was evicted from the cache.
        dummy = np.full(res_arr.shape, np.nan, dtype=res_arr.dtype)
        integrals = np.array([interval_cache.get((z[1], z[2]), dummy)
                              for z in intervals], dtype=res_arr.dtype)
        errors = np.array([-z[0] for z in intervals])
        intervals = np.array([[z[1], z[2]] for z in intervals])

        info = _Bunch(neval=neval,
                      success=(ier == CONVERGED),
                      status=ier,
                      message=status_msg[ier],
                      intervals=intervals,
                      integrals=integrals,
                      errors=errors)
        return (res, err, info)
    else:
        return (res, err)
416def _subdivide_interval(args):
417 interval, f, norm_func, _quadrature = args
418 old_err, a, b, old_int = interval
420 c = 0.5 * (a + b)
422 # Left-hand side
423 if getattr(_quadrature, 'cache_size', 0) > 0:
424 f = functools.lru_cache(_quadrature.cache_size)(f)
426 s1, err1, round1 = _quadrature(a, c, f, norm_func)
427 dneval = _quadrature.num_eval
428 s2, err2, round2 = _quadrature(c, b, f, norm_func)
429 dneval += _quadrature.num_eval
430 if old_int is None:
431 old_int, _, _ = _quadrature(a, b, f, norm_func)
432 dneval += _quadrature.num_eval
434 if getattr(_quadrature, 'cache_size', 0) > 0:
435 dneval = f.cache_info().misses
437 dint = s1 + s2 - old_int
438 derr = err1 + err2 - old_err
439 dround_err = round1 + round2
441 subintervals = ((a, c, s1, err1), (c, b, s2, err2))
442 return dint, derr, dround_err, subintervals, dneval
445def _quadrature_trapz(x1, x2, f, norm_func):
446 """
447 Composite trapezoid quadrature
448 """
449 x3 = 0.5*(x1 + x2)
450 f1 = f(x1)
451 f2 = f(x2)
452 f3 = f(x3)
454 s2 = 0.25 * (x2 - x1) * (f1 + 2*f3 + f2)
456 round_err = 0.25 * abs(x2 - x1) * (float(norm_func(f1))
457 + 2*float(norm_func(f3))
458 + float(norm_func(f2))) * 2e-16
460 s1 = 0.5 * (x2 - x1) * (f1 + f2)
461 err = 1/3 * float(norm_func(s1 - s2))
462 return s2, err, round_err
465_quadrature_trapz.cache_size = 3 * 3
466_quadrature_trapz.num_eval = 3
469def _quadrature_gk(a, b, f, norm_func, x, w, v):
470 """
471 Generic Gauss-Kronrod quadrature
472 """
474 fv = [0.0]*len(x)
476 c = 0.5 * (a + b)
477 h = 0.5 * (b - a)
479 # Gauss-Kronrod
480 s_k = 0.0
481 s_k_abs = 0.0
482 for i in range(len(x)):
483 ff = f(c + h*x[i])
484 fv[i] = ff
486 vv = v[i]
488 # \int f(x)
489 s_k += vv * ff
490 # \int |f(x)|
491 s_k_abs += vv * abs(ff)
493 # Gauss
494 s_g = 0.0
495 for i in range(len(w)):
496 s_g += w[i] * fv[2*i + 1]
498 # Quadrature of abs-deviation from average
499 s_k_dabs = 0.0
500 y0 = s_k / 2.0
501 for i in range(len(x)):
502 # \int |f(x) - y0|
503 s_k_dabs += v[i] * abs(fv[i] - y0)
505 # Use similar error estimation as quadpack
506 err = float(norm_func((s_k - s_g) * h))
507 dabs = float(norm_func(s_k_dabs * h))
508 if dabs != 0 and err != 0:
509 err = dabs * min(1.0, (200 * err / dabs)**1.5)
511 eps = sys.float_info.epsilon
512 round_err = float(norm_func(50 * eps * h * s_k_abs))
514 if round_err > sys.float_info.min:
515 err = max(err, round_err)
517 return h * s_k, err, round_err
520def _quadrature_gk21(a, b, f, norm_func):
521 """
522 Gauss-Kronrod 21 quadrature with error estimate
523 """
524 # Gauss-Kronrod points
525 x = (0.995657163025808080735527280689003,
526 0.973906528517171720077964012084452,
527 0.930157491355708226001207180059508,
528 0.865063366688984510732096688423493,
529 0.780817726586416897063717578345042,
530 0.679409568299024406234327365114874,
531 0.562757134668604683339000099272694,
532 0.433395394129247190799265943165784,
533 0.294392862701460198131126603103866,
534 0.148874338981631210884826001129720,
535 0,
536 -0.148874338981631210884826001129720,
537 -0.294392862701460198131126603103866,
538 -0.433395394129247190799265943165784,
539 -0.562757134668604683339000099272694,
540 -0.679409568299024406234327365114874,
541 -0.780817726586416897063717578345042,
542 -0.865063366688984510732096688423493,
543 -0.930157491355708226001207180059508,
544 -0.973906528517171720077964012084452,
545 -0.995657163025808080735527280689003)
547 # 10-point weights
548 w = (0.066671344308688137593568809893332,
549 0.149451349150580593145776339657697,
550 0.219086362515982043995534934228163,
551 0.269266719309996355091226921569469,
552 0.295524224714752870173892994651338,
553 0.295524224714752870173892994651338,
554 0.269266719309996355091226921569469,
555 0.219086362515982043995534934228163,
556 0.149451349150580593145776339657697,
557 0.066671344308688137593568809893332)
559 # 21-point weights
560 v = (0.011694638867371874278064396062192,
561 0.032558162307964727478818972459390,
562 0.054755896574351996031381300244580,
563 0.075039674810919952767043140916190,
564 0.093125454583697605535065465083366,
565 0.109387158802297641899210590325805,
566 0.123491976262065851077958109831074,
567 0.134709217311473325928054001771707,
568 0.142775938577060080797094273138717,
569 0.147739104901338491374841515972068,
570 0.149445554002916905664936468389821,
571 0.147739104901338491374841515972068,
572 0.142775938577060080797094273138717,
573 0.134709217311473325928054001771707,
574 0.123491976262065851077958109831074,
575 0.109387158802297641899210590325805,
576 0.093125454583697605535065465083366,
577 0.075039674810919952767043140916190,
578 0.054755896574351996031381300244580,
579 0.032558162307964727478818972459390,
580 0.011694638867371874278064396062192)
582 return _quadrature_gk(a, b, f, norm_func, x, w, v)
585_quadrature_gk21.num_eval = 21
588def _quadrature_gk15(a, b, f, norm_func):
589 """
590 Gauss-Kronrod 15 quadrature with error estimate
591 """
592 # Gauss-Kronrod points
593 x = (0.991455371120812639206854697526329,
594 0.949107912342758524526189684047851,
595 0.864864423359769072789712788640926,
596 0.741531185599394439863864773280788,
597 0.586087235467691130294144838258730,
598 0.405845151377397166906606412076961,
599 0.207784955007898467600689403773245,
600 0.000000000000000000000000000000000,
601 -0.207784955007898467600689403773245,
602 -0.405845151377397166906606412076961,
603 -0.586087235467691130294144838258730,
604 -0.741531185599394439863864773280788,
605 -0.864864423359769072789712788640926,
606 -0.949107912342758524526189684047851,
607 -0.991455371120812639206854697526329)
609 # 7-point weights
610 w = (0.129484966168869693270611432679082,
611 0.279705391489276667901467771423780,
612 0.381830050505118944950369775488975,
613 0.417959183673469387755102040816327,
614 0.381830050505118944950369775488975,
615 0.279705391489276667901467771423780,
616 0.129484966168869693270611432679082)
618 # 15-point weights
619 v = (0.022935322010529224963732008058970,
620 0.063092092629978553290700663189204,
621 0.104790010322250183839876322541518,
622 0.140653259715525918745189590510238,
623 0.169004726639267902826583426598550,
624 0.190350578064785409913256402421014,
625 0.204432940075298892414161999234649,
626 0.209482141084727828012999174891714,
627 0.204432940075298892414161999234649,
628 0.190350578064785409913256402421014,
629 0.169004726639267902826583426598550,
630 0.140653259715525918745189590510238,
631 0.104790010322250183839876322541518,
632 0.063092092629978553290700663189204,
633 0.022935322010529224963732008058970)
635 return _quadrature_gk(a, b, f, norm_func, x, w, v)
638_quadrature_gk15.num_eval = 15