# This module contains some missing ops from jax
import functools
import jax.numpy as np
from jax import vmap
from jax.numpy import array
from jax.numpy import concatenate
from jax.numpy import ones
from jax.numpy import zeros
from jax.tree_util import register_pytree_node_class
__all__ = ["interp"]
[docs]@functools.partial(vmap, in_axes=(0, None, None))
def interp(x, xp, fp):
"""
Simple equivalent of np.interp that compute a linear interpolation.
We are not doing any checks, so make sure your query points are lying
inside the array.
TODO: Implement proper interpolation!
x, xp, fp need to be 1d arrays
"""
# First we find the nearest neighbour
ind = np.argmin((x - xp) ** 2)
# Perform linear interpolation
ind = np.clip(ind, 1, len(xp) - 2)
xi = xp[ind]
# Figure out if we are on the right or the left of nearest
s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
a = (fp[ind + np.copysign(1, s).astype(np.int64)] - fp[ind]) / (
xp[ind + np.copysign(1, s).astype(np.int64)] - xp[ind]
)
b = fp[ind] - a * xp[ind]
return a * x + b
@register_pytree_node_class
class InterpolatedUnivariateSpline(object):
def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None):
"""JAX implementation of kth-order spline interpolation.
This class aims to reproduce scipy's InterpolatedUnivariateSpline
functionality using JAX. Not all of the original class's features
have been implemented yet, notably
- `w` : no weights are used in the spline fitting.
- `bbox` : we assume the boundary to always be [x[0], x[-1]].
- `ext` : extrapolation is always active, i.e., `ext` = 0.
- `k` : orders `k` > 3 are not available.
- `check_finite` : no such check is performed.
(The relevant lines from the original docstring have been included
in the following.)
Fits a spline y = spl(x) of degree `k` to the provided `x`, `y` data.
Spline function passes through all provided points. Equivalent to
`UnivariateSpline` with s = 0.
Parameters
----------
x : (N,) array_like
Input dimension of data points -- must be strictly increasing
y : (N,) array_like
input dimension of data points
k : int, optional
Degree of the smoothing spline. Must be 1 <= `k` <= 3.
endpoints : str, optional, one of {'natural', 'not-a-knot'}
Endpoint condition for cubic splines, i.e., `k` = 3.
'natural' endpoints enforce a vanishing second derivative
of the spline at the two endpoints, while 'not-a-knot'
ensures that the third derivatives are equal for the two
left-most `x` of the domain, as well as for the two
right-most `x`. The original scipy implementation uses
'not-a-knot'.
coefficients: list, optional
Precomputed parameters for spline interpolation. Shouldn't be set
manually.
See Also
--------
UnivariateSpline : Superclass -- allows knots to be selected by a
smoothing condition
LSQUnivariateSpline : spline for which knots are user-selected
splrep : An older, non object-oriented wrapping of FITPACK
splev, sproot, splint, spalde
BivariateSpline : A similar class for two-dimensional spline interpolation
Notes
-----
The number of data points must be larger than the spline degree `k`.
The general form of the spline can be written as
f[i](x) = a[i] + b[i](x - x[i]) + c[i](x - x[i])^2 + d[i](x - x[i])^3,
i = 0, ..., n-1,
where d = 0 for `k` = 2, and c = d = 0 for `k` = 1.
The unknown coefficients (a, b, c, d) define a symmetric, diagonal
linear system of equations, Az = s, where z = b for `k` = 1 and `k` = 2,
and z = c for `k` = 3. In each case, the coefficients defining each
spline piece can be expressed in terms of only z[i], z[i+1],
y[i], and y[i+1]. The coefficients are solved for using
`np.linalg.solve` when `k` = 2 and `k` = 3.
"""
# Verify inputs
k = int(k)
assert k in (1, 2, 3), "Order k must be in {1, 2, 3}."
x = np.atleast_1d(x)
y = np.atleast_1d(y)
assert len(x) == len(y), "Input arrays must be the same length."
assert x.ndim == 1 and y.ndim == 1, "Input arrays must be 1D."
n_data = len(x)
# Difference vectors
h = np.diff(x) # x[i+1] - x[i] for i=0,...,n-1
p = np.diff(y) # y[i+1] - y[i]
if coefficients is None:
# Build the linear system of equations depending on k
# (No matrix necessary for k=1)
if k == 1:
assert n_data > 1, "Not enough input points for linear spline."
coefficients = p / h
if k == 2:
assert n_data > 2, "Not enough input points for quadratic spline."
assert endpoints == "not-a-knot" # I have only validated this
# And actually I think it's probably the best choice of border condition
# The knots are actually in between data points
knots = (x[1:] + x[:-1]) / 2.0
# We add 2 artificial knots before and after
knots = np.concatenate(
[
np.array([x[0] - (x[1] - x[0]) / 2.0]),
knots,
np.array([x[-1] + (x[-1] - x[-2]) / 2.0]),
]
)
n = len(knots)
# Compute interval lenghts for these new knots
h = np.diff(knots)
# postition of data point inside the interval
dt = x - knots[:-1]
# Now we build the system natrix
A = np.diag(
np.concatenate(
[
np.ones(1),
(
2 * dt[1:]
- dt[1:] ** 2 / h[1:]
- dt[:-1] ** 2 / h[:-1]
+ h[:-1]
),
np.ones(1),
]
)
)
A += np.diag(
np.concatenate([-np.array([1 + h[0] / h[1]]), dt[1:] ** 2 / h[1:]]),
k=1,
)
A += np.diag(
np.concatenate([np.atleast_1d(h[0] / h[1]), np.zeros(n - 3)]), k=2
)
A += np.diag(
np.concatenate(
[
h[:-1] - 2 * dt[:-1] + dt[:-1] ** 2 / h[:-1],
-np.array([1 + h[-1] / h[-2]]),
]
),
k=-1,
)
A += np.diag(
np.concatenate([np.zeros(n - 3), np.atleast_1d(h[-1] / h[-2])]),
k=-2,
)
# And now we build the RHS vector
s = np.concatenate([np.zeros(1), 2 * p, np.zeros(1)])
# Compute spline coefficients by solving the system
coefficients = np.linalg.solve(A, s)
if k == 3:
assert n_data > 3, "Not enough input points for cubic spline."
if endpoints not in ("natural", "not-a-knot"):
print("Warning : endpoints not recognized. Using natural.")
endpoints = "natural"
# Special values for the first and last equations
zero = array([0.0])
one = array([1.0])
A00 = one if endpoints == "natural" else array([h[1]])
A01 = zero if endpoints == "natural" else array([-(h[0] + h[1])])
A02 = zero if endpoints == "natural" else array([h[0]])
ANN = one if endpoints == "natural" else array([h[-2]])
AN1 = (
-one if endpoints == "natural" else array([-(h[-2] + h[-1])])
) # A[N, N-1]
AN2 = zero if endpoints == "natural" else array([h[-1]]) # A[N, N-2]
# Construct the tri-diagonal matrix A
A = np.diag(concatenate((A00, 2 * (h[:-1] + h[1:]), ANN)))
upper_diag1 = np.diag(concatenate((A01, h[1:])), k=1)
upper_diag2 = np.diag(concatenate((A02, zeros(n_data - 3))), k=2)
lower_diag1 = np.diag(concatenate((h[:-1], AN1)), k=-1)
lower_diag2 = np.diag(concatenate((zeros(n_data - 3), AN2)), k=-2)
A += upper_diag1 + upper_diag2 + lower_diag1 + lower_diag2
# Construct RHS vector s
center = 3 * (p[1:] / h[1:] - p[:-1] / h[:-1])
s = concatenate((zero, center, zero))
# Compute spline coefficients by solving the system
coefficients = np.linalg.solve(A, s)
# Saving spline parameters for evaluation later
self.k = k
self._x = x
self._y = y
self._coefficients = coefficients
# Operations for flattening/unflattening representation
def tree_flatten(self):
children = (self._x, self._y, self._coefficients)
aux_data = {"endpoints": self._endpoints, "k": self.k}
return (children, aux_data)
@classmethod
def tree_unflatten(cls, aux_data, children):
x, y, coefficients = children
return cls(x, y, coefficients=coefficients, **aux_data)
def __call__(self, x):
"""Evaluation of the spline.
Notes
-----
Values are extrapolated if x is outside of the original domain
of knots. If x is less than the left-most knot, the spline piece
f[0] is used for the evaluation; similarly for x beyond the
right-most point.
"""
if self.k == 1:
t, a, b = self._compute_coeffs(x)
result = a + b * t
if self.k == 2:
t, a, b, c = self._compute_coeffs(x)
result = a + b * t + c * t**2
if self.k == 3:
t, a, b, c, d = self._compute_coeffs(x)
result = a + b * t + c * t**2 + d * t**3
return result
def _compute_coeffs(self, xs):
"""Compute the spline coefficients for a given x."""
# Retrieve parameters
x, y, coefficients = self._x, self._y, self._coefficients
# In case of quadratic, we redefine the knots
if self.k == 2:
knots = (x[1:] + x[:-1]) / 2.0
# We add 2 artificial knots before and after
knots = np.concatenate(
[
np.array([x[0] - (x[1] - x[0]) / 2.0]),
knots,
np.array([x[-1] + (x[-1] - x[-2]) / 2.0]),
]
)
else:
knots = x
# Determine the interval that x lies in
ind = np.digitize(xs, knots) - 1
# Include the right endpoint in spline piece C[m-1]
ind = np.clip(ind, 0, len(knots) - 2)
t = xs - knots[ind]
h = np.diff(knots)[ind]
if self.k == 1:
a = y[ind]
result = (t, a, coefficients[ind])
if self.k == 2:
dt = (x - knots[:-1])[ind]
b = coefficients[ind]
b1 = coefficients[ind + 1]
a = y[ind] - b * dt - (b1 - b) * dt**2 / (2 * h)
c = (b1 - b) / (2 * h)
result = (t, a, b, c)
if self.k == 3:
c = coefficients[ind]
c1 = coefficients[ind + 1]
a = y[ind]
a1 = y[ind + 1]
b = (a1 - a) / h - (2 * c + c1) * h / 3.0
d = (c1 - c) / (3 * h)
result = (t, a, b, c, d)
return result
def derivative(self, x, n=1):
"""Analytic nth derivative of the spline.
The spline has derivatives up to its order k.
"""
assert n in range(self.k + 1), "Invalid n."
if n == 0:
result = self.__call__(x)
else:
# Linear
if self.k == 1:
t, a, b = self._compute_coeffs(x)
result = b
# Quadratic
if self.k == 2:
t, a, b, c = self._compute_coeffs(x)
if n == 1:
result = b + 2 * c * t
if n == 2:
result = 2 * c
# Cubic
if self.k == 3:
t, a, b, c, d = self._compute_coeffs(x)
if n == 1:
result = b + 2 * c * t + 3 * d * t**2
if n == 2:
result = 2 * c + 6 * d * t
if n == 3:
result = 6 * d
return result
def antiderivative(self, xs):
"""
Computes the antiderivative of first order of this spline
"""
# Retrieve parameters
x, y, coefficients = self._x, self._y, self._coefficients
# In case of quadratic, we redefine the knots
if self.k == 2:
knots = (x[1:] + x[:-1]) / 2.0
# We add 2 artificial knots before and after
knots = np.concatenate(
[
np.array([x[0] - (x[1] - x[0]) / 2.0]),
knots,
np.array([x[-1] + (x[-1] - x[-2]) / 2.0]),
]
)
else:
knots = x
# Determine the interval that x lies in
ind = np.digitize(xs, knots) - 1
# Include the right endpoint in spline piece C[m-1]
ind = np.clip(ind, 0, len(knots) - 2)
t = xs - knots[ind]
if self.k == 1:
a = y[:-1]
b = coefficients
h = np.diff(knots)
cst = np.concatenate([np.zeros(1), np.cumsum(a * h + b * h**2 / 2)])
return cst[ind] + a[ind] * t + b[ind] * t**2 / 2
if self.k == 2:
h = np.diff(knots)
dt = x - knots[:-1]
b = coefficients[:-1]
b1 = coefficients[1:]
a = y - b * dt - (b1 - b) * dt**2 / (2 * h)
c = (b1 - b) / (2 * h)
cst = np.concatenate(
[np.zeros(1), np.cumsum(a * h + b * h**2 / 2 + c * h**3 / 3)]
)
return cst[ind] + a[ind] * t + b[ind] * t**2 / 2 + c[ind] * t**3 / 3
if self.k == 3:
h = np.diff(knots)
c = coefficients[:-1]
c1 = coefficients[1:]
a = y[:-1]
a1 = y[1:]
b = (a1 - a) / h - (2 * c + c1) * h / 3.0
d = (c1 - c) / (3 * h)
cst = np.concatenate(
[
np.zeros(1),
np.cumsum(a * h + b * h**2 / 2 + c * h**3 / 3 + d * h**4 / 4),
]
)
return (
cst[ind]
+ a[ind] * t
+ b[ind] * t**2 / 2
+ c[ind] * t**3 / 3
+ d[ind] * t**4 / 4
)
def integral(self, a, b):
"""
Compute a definite integral over a piecewise polynomial.
Parameters
----------
a : float
Lower integration bound
b : float
Upper integration bound
Returns
-------
ig : array_like
Definite integral of the piecewise polynomial over [a, b]
"""
# Swap integration bounds if needed
sign = 1
if b < a:
a, b = b, a
sign = -1
xs = np.array([a, b])
return sign * np.diff(self.antiderivative(xs))