from functools import partial
import jax
import jax.numpy as np
from jax import jit
from jax import vmap
__all__ = ["romb", "simps"]
# Romberg quadratures for numeric integration.
#
# Written by Scott M. Ransom <ransom@cfa.harvard.edu>
# last revision: 14 Nov 98
#
# Cosmetic changes by Konrad Hinsen <hinsen@cnrs-orleans.fr>
# last revision: 1999-7-21
#
# Adapted to scipy by Travis Oliphant <oliphant.travis@ieee.org>
# last revision: Dec 2001
def _difftrap1(function, interval):
"""
Perform part of the trapezoidal rule to integrate a function.
Assume that we had called difftrap with all lower powers-of-2
starting with 1. Calling difftrap only returns the summation
of the new ordinates. It does _not_ multiply by the width
of the trapezoids. This must be performed by the caller.
'function' is the function to evaluate (must accept vector arguments).
'interval' is a sequence with lower and upper limits
of integration.
'numtraps' is the number of trapezoids to use (must be a
power-of-2).
"""
return 0.5 * (function(interval[0]) + function(interval[1]))
def _difftrapn(function, interval, numtraps):
"""
Perform part of the trapezoidal rule to integrate a function.
Assume that we had called difftrap with all lower powers-of-2
starting with 1. Calling difftrap only returns the summation
of the new ordinates. It does _not_ multiply by the width
of the trapezoids. This must be performed by the caller.
'function' is the function to evaluate (must accept vector arguments).
'interval' is a sequence with lower and upper limits
of integration.
'numtraps' is the number of trapezoids to use (must be a
power-of-2).
"""
numtosum = numtraps // 2
h = (1.0 * interval[1] - 1.0 * interval[0]) / numtosum
lox = interval[0] + 0.5 * h
points = lox + h * np.arange(0, numtosum)
s = np.sum(function(points))
return s
def _romberg_diff(b, c, k):
"""
Compute the differences for the Romberg quadrature corrections.
See Forman Acton's "Real Computing Made Real," p 143.
"""
tmp = 4.0**k
return (tmp * c - b) / (tmp - 1.0)
[docs]def romb(function, a, b, args=(), divmax=6, return_error=False):
"""
Romberg integration of a callable function or method.
Returns the integral of `function` (a function of one variable)
over the interval (`a`, `b`).
If `show` is 1, the triangular array of the intermediate results
will be printed. If `vec_func` is True (default is False), then
`function` is assumed to support vector arguments.
Parameters
----------
function : callable
Function to be integrated.
a : float
Lower limit of integration.
b : float
Upper limit of integration.
Returns
-------
results : float
Result of the integration.
Other Parameters
----------------
args : tuple, optional
Extra arguments to pass to function. Each element of `args` will
be passed as a single argument to `func`. Default is to pass no
extra arguments.
divmax : int, optional
Maximum order of extrapolation. Default is 10.
See Also
--------
fixed_quad : Fixed-order Gaussian quadrature.
quad : Adaptive quadrature using QUADPACK.
dblquad : Double integrals.
tplquad : Triple integrals.
romb : Integrators for sampled data.
simps : Integrators for sampled data.
cumtrapz : Cumulative integration for sampled data.
ode : ODE integrator.
odeint : ODE integrator.
References
----------
.. [1] 'Romberg's method' http://en.wikipedia.org/wiki/Romberg%27s_method
Examples
--------
Integrate a gaussian from 0 to 1 and compare to the error function.
>>> from scipy import integrate
>>> from scipy.special import erf
>>> gaussian = lambda x: 1/np.sqrt(np.pi) * np.exp(-x**2)
>>> result = integrate.romberg(gaussian, 0, 1, show=True)
Romberg integration of <function vfunc at ...> from [0, 1]
::
Steps StepSize Results
1 1.000000 0.385872
2 0.500000 0.412631 0.421551
4 0.250000 0.419184 0.421368 0.421356
8 0.125000 0.420810 0.421352 0.421350 0.421350
16 0.062500 0.421215 0.421350 0.421350 0.421350 0.421350
32 0.031250 0.421317 0.421350 0.421350 0.421350 0.421350 0.421350
The final result is 0.421350396475 after 33 function evaluations.
>>> print("%g %g" % (2*result, erf(1)))
0.842701 0.842701
"""
vfunc = jit(lambda x: function(x, *args))
n = 1
interval = [a, b]
intrange = b - a
ordsum = _difftrap1(vfunc, interval)
result = intrange * ordsum
state = np.repeat(np.atleast_1d(result), divmax + 1, axis=-1)
err = np.inf
def scan_fn(carry, y):
x, k = carry
x = _romberg_diff(y, x, k + 1)
return (x, k + 1), x
for i in range(1, divmax + 1):
n = 2**i
ordsum = ordsum + _difftrapn(vfunc, interval, n)
x = intrange * ordsum / n
_, new_state = jax.lax.scan(scan_fn, (x, 0), state[:-1])
new_state = np.concatenate([np.atleast_1d(x), new_state])
err = np.abs(state[i - 1] - new_state[i])
state = new_state
if return_error:
return state[i], err
else:
return state[i]
[docs]def simps(f, a, b, N=128):
"""Approximate the integral of f(x) from a to b by Simpson's rule.
Simpson's rule approximates the integral \int_a^b f(x) dx by the sum:
(dx/3) \sum_{k=1}^{N/2} (f(x_{2i-2} + 4f(x_{2i-1}) + f(x_{2i}))
where x_i = a + i*dx and dx = (b - a)/N.
Parameters
----------
f : function
Vectorized function of a single variable
a , b : numbers
Interval of integration [a,b]
N : (even) integer
Number of subintervals of [a,b]
Returns
-------
float
Approximation of the integral of f(x) from a to b using
Simpson's rule with N subintervals of equal length.
Examples
--------
>>> simps(lambda x : 3*x**2,0,1,10)
1.0
Notes:
------
Stolen from: https://www.math.ubc.ca/~pwalls/math-python/integration/simpsons-rule/
"""
if N % 2 == 1:
raise ValueError("N must be an even integer.")
dx = (b - a) / N
x = np.linspace(a, b, N + 1)
y = f(x)
S = dx / 3 * np.sum(y[0:-1:2] + 4 * y[1::2] + y[2::2], axis=0)
return S