# Source code for jax_cosmo.redshift

# Module to define redshift distributions we can differentiate through
from abc import ABC
from abc import abstractmethod

import jax.numpy as np
from jax.tree_util import register_pytree_node_class

from jax_cosmo.jax_utils import container
from jax_cosmo.scipy.integrate import simps

# Number of square arcminutes in one steradian: (180 * 60 / pi) ** 2.
# Multiplying a density per square arcminute by this factor gives the
# density per steradian.
steradian_to_arcmin2 = 11818102.86004228

# Public names exported by a star-import of this module.
__all__ = ["smail_nz", "kde_nz", "delta_nz"]


class redshift_distribution(container):
    """Base class for n(z) redshift distributions.

    Sub-classes provide the un-normalized shape through `pz_fn`; calling an
    instance returns that shape divided by its integral between 0 and zmax.
    """

    def __init__(self, *args, gals_per_arcmin2=1.0, zmax=10.0, **kwargs):
        """Set up the distribution.

        Parameters:
        -----------
        *args: positional parameters, forwarded to the container
        gals_per_arcmin2: number of galaxies per sq arcmin
        zmax: upper bound of the normalization integral, forwarded to the
            container as a keyword and read back from `self.config["zmax"]`
        **kwargs: remaining keywords, forwarded to the container
        """
        # The normalization is only computed on the first evaluation
        self._norm = None
        self._gals_per_arcmin2 = gals_per_arcmin2
        super().__init__(*args, zmax=zmax, **kwargs)

    @abstractmethod
    def pz_fn(self, z):
        """Un-normalized n(z), to be implemented by sub classes"""
        pass

    def __call__(self, z):
        """Evaluates the normalized n(z) at z"""
        if self._norm is None:
            # First evaluation: integrate the raw shape over [0, zmax] with
            # 256 Simpson intervals and keep the result for later calls.
            def unnormalized(t):
                return self.pz_fn(t)

            upper_bound = self.config["zmax"]
            self._norm = simps(unnormalized, 0.0, upper_bound, 256)
        raw_value = self.pz_fn(z)
        return raw_value / self._norm

    @property
    def zmax(self):
        """Upper redshift bound stored in the configuration"""
        return self.config["zmax"]

    @property
    def gals_per_arcmin2(self):
        """Number density of galaxies in gals/sq arcmin
        TODO: find a better name
        """
        return self._gals_per_arcmin2

    @property
    def gals_per_steradian(self):
        """Number density of galaxies in gals/steradian"""
        density_arcmin2 = self._gals_per_arcmin2
        return density_arcmin2 * steradian_to_arcmin2

    # Pytree support: the parameters and the galaxy density are the leaves,
    # the configuration dictionary travels as auxiliary data.
    def tree_flatten(self):
        """Returns the (children, aux_data) pair describing this instance"""
        leaves = (self.params, self._gals_per_arcmin2)
        return (leaves, self.config)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuilds an instance from the output of `tree_flatten`"""
        params, density = children
        return cls(*params, gals_per_arcmin2=density, **aux_data)


@register_pytree_node_class
class smail_nz(redshift_distribution):
    """Smail-type redshift distribution, n(z) ~ z**a * exp(-(z / z0)**b)

    Parameters:
    -----------
    a: exponent of the power-law factor
    b: exponent inside the exponential cut-off
    z0: redshift scale of the exponential cut-off
    gals_per_arcmin2: number of galaxies per sq arcmin
    """

    def pz_fn(self, z):
        """Un-normalized Smail n(z) evaluated at z"""
        slope, steepness, z_scale = self.params
        power_law = z**slope
        cutoff = np.exp(-((z / z_scale) ** steepness))
        return power_law * cutoff
@register_pytree_node_class
class delta_nz(redshift_distribution):
    """Defines a single plane redshift distribution with these arguments

    Parameters:
    -----------
    z0: redshift of the plane
    """

    def __init__(self, *args, **kwargs):
        """Initialize the parameters of the redshift distribution"""
        super(delta_nz, self).__init__(*args, **kwargs)
        # The base class only integrates pz_fn when _norm is None; fixing it
        # to 1 here means the indicator below is never numerically integrated.
        self._norm = 1.0

    def pz_fn(self, z):
        """Returns 1.0 where z equals z0 and 0 elsewhere"""
        # params holds the positional arguments as a sequence, so the plane
        # redshift is its first entry; comparing z against the whole
        # sequence would not be an element-wise test against z0.
        z0 = self.params[0]
        return np.where(z == z0, 1.0, 0)
@register_pytree_node_class
class kde_nz(redshift_distribution):
    """Redshift distribution built from a KDE estimate of the n(z) of a
    given catalog. Only a Gaussian kernel is available for now.
    TODO: add more if necessary

    Parameters:
    -----------
    zcat: redshift catalog
    weights: weight for each galaxy between 0 and 1

    Configuration:
    --------------
    bw: Bandwidth for the KDE

    Example:
    nz = kde_nz(redshift_catalog, w, bw=0.1)
    """

    def _kernel(self, bw, X, x):
        """Gaussian kernel of bandwidth bw evaluated on X - x"""
        amplitude = 1.0 / np.sqrt(2 * np.pi) / bw
        exponent = -((X - x) ** 2) / (bw**2 * 2.0)
        return amplitude * np.exp(exponent)

    def pz_fn(self, z):
        """Weighted sum of the kernels, divided by the sum of the weights"""
        # Only the first two parameters are used: catalog and weights
        catalog, weight = self.params[:2]
        weights = np.atleast_1d(weight)
        weight_sum = np.sum(weights)
        # Trailing axis lets the catalog broadcast against z in the kernel
        catalog_column = np.expand_dims(catalog, axis=-1)
        kernel_values = self._kernel(self.config["bw"], catalog_column, z)
        return np.dot(kernel_values.T, weights) / (weight_sum)
@register_pytree_node_class
class systematic_shift(redshift_distribution):
    """Applies a systematic shift to an existing redshift distribution
    TODO: Find a better name for this

    Arguments:
    redshift_distribution
    mean_bias
    """

    def pz_fn(self, z):
        """Parent un-normalized n(z) evaluated at z - bias, floored at 0"""
        # Only the first two parameters are used: parent n(z) and bias
        base_nz, shift = self.params[:2]
        shifted_z = np.clip(z - shift, 0)
        return base_nz.pz_fn(shifted_z)