# Source code for jax_cosmo.scipy.ode
# This module provides custom ODE solver code.
import jax
import jax.numpy as np
def odeint(fn, y0, t):
    """Integrate an ODE with a fixed-step, classical 4th-order Runge-Kutta scheme.

    This is a dead-simple RK4 solver with no custom gradients: derivatives
    with respect to inputs are obtained by differentiating straight through
    the ``jax.lax.scan`` loop.

    Parameters
    ----------
    fn : callable
        Right-hand side ``fn(y, t)`` returning ``dy/dt``. It is expected to
        return an array with the same shape as ``y`` (assumed, not checked).
    y0 : array_like
        State at time ``t[0]``.
    t : array_like
        1-D sequence of time points at which the solution is reported. Each
        RK4 step spans ``t[i] - t[i-1]``, so the spacing of ``t`` sets the
        accuracy of the integration.

    Returns
    -------
    jax.numpy.ndarray
        Stacked states with leading axis of length ``len(t)``; ``ys[i]`` is
        the solution at ``t[i]``. ``ys[0]`` reproduces ``y0`` because the
        first step has zero width.
    """
    # lax.scan iterates over the leading axis of array leaves, so plain
    # Python sequences must be converted to arrays first.
    t = np.asarray(t)
    y0 = np.asarray(y0)

    def rk4(carry, t_next):
        # One classical RK4 step from t_prev to t_next.
        y, t_prev = carry
        h = t_next - t_prev
        k1 = fn(y, t_prev)
        k2 = fn(y + h * k1 / 2, t_prev + h / 2)
        k3 = fn(y + h * k2 / 2, t_prev + h / 2)
        k4 = fn(y + h * k3, t_next)
        y = y + 1.0 / 6.0 * h * (k1 + 2 * k2 + 2 * k3 + k4)
        return (y, t_next), y

    # Seeding the carry time with t[0] makes the first scanned step have
    # h == 0, so the first stacked output is the initial state itself.
    _, ys = jax.lax.scan(rk4, (y0, t[0]), t)
    return ys