jax_cosmo.jax_utils module

class jax_cosmo.jax_utils.container(*args, **kwargs)

Bases: object

Generic structure for tracing a parameterized function.

Parameters of the object, i.e. quantities that need to be traced for autodiff, are stored as a list in self.params. Configuration arguments, i.e. static values that do not need to be traced, are stored as a dictionary in self.config. This is intended for things like flags or the type of PS.
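The params/config split described above maps onto JAX's custom pytree protocol: positional arguments become traced leaves, keyword arguments become static auxiliary data. The following is a minimal sketch under that assumption, not jax_cosmo's actual source. The names Container and PowerLaw, and the log flag, are hypothetical and chosen only for illustration. It shows jax.grad differentiating with respect to params while config is carried through unchanged.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Container:
    """Hypothetical stand-in for a params/config container (not the library code)."""

    def __init__(self, *args, **kwargs):
        # Traced parameters: these become pytree leaves that JAX differentiates.
        self.params = args
        # Static configuration: kept as auxiliary data and never traced.
        self.config = kwargs

    def tree_flatten(self):
        # (children, aux_data); the sorted tuple keeps the static data hashable.
        return self.params, tuple(sorted(self.config.items()))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the object from traced children and static config.
        return cls(*children, **dict(aux_data))


# Pytree registration is per type, so each subclass is registered as well.
@register_pytree_node_class
class PowerLaw(Container):
    """Hypothetical parameterized function: amplitude * x**index."""

    def __call__(self, x):
        amplitude, index = self.params
        y = amplitude * x ** index
        # `log` is static, so plain Python control flow is allowed here.
        return jnp.log(y) if self.config.get("log", False) else y


def loss(model, x):
    return jnp.sum(model(x))


model = PowerLaw(2.0, 1.5, log=False)
x = jnp.array([1.0, 2.0, 3.0])

# Gradient w.r.t. the traced params, returned as a PowerLaw with the same config.
grads = jax.grad(loss)(model, x)
print(grads.params, grads.config)
```

Because the gradient comes back as an instance of the same class, downstream code can treat parameters and their gradients uniformly. Converting the config dictionary to a sorted tuple is a design choice made here so the auxiliary data stays hashable (assuming the config values are themselves hashable).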

tree_flatten()
classmethod tree_unflatten(aux_data, children)
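JAX calls tree_flatten and tree_unflatten whenever it walks one of these objects. The self-contained sketch below uses a hypothetical Container that assumes the params/config layout described for this class; it is not the library implementation. It invokes the two methods by hand and then through jax.tree_util, showing that only params become leaves while config survives the round trip.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Container:
    """Hypothetical minimal container mirroring the documented params/config split."""

    def __init__(self, *args, **kwargs):
        self.params = args    # traced leaves
        self.config = kwargs  # static auxiliary data

    def tree_flatten(self):
        # Return (children, aux_data); aux_data is a hashable sorted tuple.
        return self.params, tuple(sorted(self.config.items()))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, **dict(aux_data))


c = Container(0.3, 0.8, kind="linear")

# Manual round trip: children are the params, aux_data holds the config.
children, aux_data = c.tree_flatten()
rebuilt = Container.tree_unflatten(aux_data, children)

# JAX uses the same two methods when flattening or mapping over the tree.
leaves, treedef = jax.tree_util.tree_flatten(c)
doubled = jax.tree_util.tree_map(lambda p: 2 * p, c)  # config passes through untouched
```

Note that tree_unflatten is a classmethod, so it reconstructs an instance of whichever registered class is being unflattened.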