Source code for jax_cosmo.jax_utils

import jax.numpy as np


class container(object):
    """Generic container that lets a parameterized object act as a JAX pytree.

    Positional constructor arguments are the *parameters*: values that need to
    be traced for autodiff. They are kept, in call order, in the tuple
    ``self.params`` and are reported as the pytree children by
    :meth:`tree_flatten`.

    Keyword constructor arguments are the *configuration*: static settings
    that are not traced, such as flags or the type of PS. They are kept in
    the dict ``self.config`` and are reported as the pytree auxiliary data.

    Note: ``tree_flatten`` / ``tree_unflatten`` follow the custom-pytree
    protocol used by ``jax.tree_util.register_pytree_node_class``. No such
    registration is applied to this class itself, so a (sub)class must be
    registered before JAX will flatten it through these hooks.
    """

    def __init__(self, *args, **kwargs):
        # Traced values (tuple) and static settings (dict), stored unchanged.
        self.params, self.config = args, kwargs

    def __repr__(self):
        # Only the traced parameters are shown; the config is omitted.
        return f"{self.params}"

    # Operations for flattening/unflattening representation

    def tree_flatten(self):
        """Split the object into ``(children, aux_data)``.

        Returns:
            A pair whose first item is the ``self.params`` tuple itself and
            whose second item is the ``self.config`` dict itself (no copies).
        """
        return self.params, self.config

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance of ``cls`` from flattened pieces.

        Args:
            aux_data: Mapping that is passed back as keyword arguments.
            children: Iterable that is passed back as positional arguments.

        Returns:
            A new instance of ``cls``, so subclasses round-trip to their own
            type.
        """
        positional, keywords = children, aux_data
        return cls(*positional, **keywords)