# This module defines kernel functions for various tracers
import jax.numpy as np
from jax import jit
from jax import vmap
from jax.tree_util import register_pytree_node_class
import jax_cosmo.background as bkgrd
import jax_cosmo.constants as const
import jax_cosmo.redshift as rds
from jax_cosmo.jax_utils import container
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a
# Public probe classes exported by ``from <module> import *``
__all__ = ["WeakLensing", "NumberCounts"]
@jit
def weak_lensing_kernel(cosmo, pzs, z, ell):
    """
    Returns the weak lensing kernel for every redshift bin in ``pzs``.

    Bins described by extended redshift distributions are integrated over
    source redshift with Simpson's rule. Bins that are ``rds.delta_nz`` are
    evaluated directly at their single source redshift ``pz.params[0]``.
    The two groups are computed separately and then reassembled in the
    original bin order.

    Parameters:
    -----------
    cosmo: cosmology, forwarded to the background functions
    pzs: list of redshift distributions, one per bin
    z: redshift(s) at which the kernel is evaluated
    ell: multipole entering the ell-dependent prefactor

    Returns:
    --------
    kernel: one row per bin, i.e. shape (nbins, nz) for a scalar ``ell``
    """
    z = np.atleast_1d(z)
    # Comoving distance to the redshifts at which the kernel is evaluated
    chi = bkgrd.radial_comoving_distance(cosmo, z2a(z))

    # Split the bins into extended distributions and delta functions, since
    # each group goes through a different code path below.
    pzs_extended_idx = [
        i for i, pz in enumerate(pzs) if not isinstance(pz, rds.delta_nz)
    ]
    pzs_delta_idx = [i for i, pz in enumerate(pzs) if isinstance(pz, rds.delta_nz)]

    # Rows are produced with all extended bins first, then all delta bins.
    # ``inv[i]`` is the row holding original bin ``i``. The bin layout is
    # static, so the inverse permutation is built in plain Python.
    perm = pzs_extended_idx + pzs_delta_idx
    inv = [0] * len(perm)
    for row, bin_idx in enumerate(perm):
        inv[bin_idx] = row

    radial_kernels = []

    # Extended distributions: integrate n(z') (chi' - chi) / chi' from z to zmax
    if pzs_extended_idx:
        # The upper limit only considers the extended bins. Otherwise a
        # delta-function bin with a larger zmax would change the integration
        # range (and the Simpson resolution) of all the other bins.
        zmax = max(pzs[i].zmax for i in pzs_extended_idx)

        @vmap
        def integrand(z_prime):
            chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
            # Stack the dndz of the extended redshift bins only
            dndz = np.stack([pzs[i](z_prime) for i in pzs_extended_idx], axis=0)
            return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0)

        radial_kernels.append(simps(integrand, z, zmax, 256) * (1.0 + z) * chi)

    # Delta functions: evaluate the lensing efficiency at each source redshift
    if pzs_delta_idx:

        @vmap
        def integrand_single(z_prime):
            chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
            return np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0)

        z_sources = np.array([pzs[i].params[0] for i in pzs_delta_idx])
        radial_kernels.append(integrand_single(z_sources) * (1.0 + z) * chi)

    # Fuse both groups, then restore the original bin order
    radial_kernel = np.concatenate(radial_kernels, axis=0)
    radial_kernel = radial_kernel[np.array(inv, dtype=np.int32)]

    # Constant prefactor: 3 H0^2 Omega_m / (2 c)
    constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c
    # Ell-dependent factor: sqrt((l-1) l (l+1) (l+2)) / (l + 1/2)^2
    ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2
    return constant_factor * ell_factor * radial_kernel
@jit
def density_kernel(cosmo, pzs, bias, z, ell):
    """
    Computes the number counts density kernel for every redshift bin.

    The radial kernel is n(z) * b(z) * H(z), with one row per bin.

    Parameters:
    -----------
    cosmo: cosmology, forwarded to the bias and background functions
    pzs: list of redshift distributions, one per bin. Delta-function
        distributions are not supported.
    bias: a single bias callable ``bias(cosmo, z)`` shared by all bins, or a
        list/tuple of such callables with one entry per bin
    z: redshift(s) at which the kernel is evaluated
    ell: multipole. It is unused here and only accepted so that the
        signature matches the other kernels.

    Raises:
    -------
    NotImplementedError: if any of ``pzs`` is a ``rds.delta_nz``
    """
    if any(isinstance(pz, rds.delta_nz) for pz in pzs):
        raise NotImplementedError(
            "Density kernel not properly implemented for delta redshift distributions"
        )
    # Stack the dndz of all redshift bins
    dndz = np.stack([pz(z) for pz in pzs], axis=0)
    # Galaxy bias: either one bias per bin (list/tuple) or one shared bias
    if isinstance(bias, (list, tuple)):
        b = np.stack([bias_i(cosmo, z) for bias_i in bias], axis=0)
    else:
        b = bias(cosmo, z)
    # Radial density kernel: n(z) b(z) H(z)
    radial_kernel = dndz * b * bkgrd.H(cosmo, z2a(z))
    # Normalization
    constant_factor = 1.0
    # The density kernel has no ell dependence
    ell_factor = 1.0
    return constant_factor * ell_factor * radial_kernel
@jit
def nla_kernel(cosmo, pzs, bias, z, ell):
    """
    Computes the NLA intrinsic alignment kernel for every redshift bin.

    The radial part has the same form as the density kernel,
    n(z) * b_IA(z) * H(z). It is rescaled by the A_IA normalization
    -5e-14 * rhocrit * Omega_m / D(a) of Joachimi et al. (2011),
    arXiv:1008.3491, Eq. 6. It is then multiplied by the same ell-dependent
    factor as the weak lensing kernel.

    Parameters:
    -----------
    cosmo: cosmology, forwarded to the bias and background functions
    pzs: list of redshift distributions, one per bin. Delta-function
        distributions are not supported.
    bias: a single IA bias callable ``bias(cosmo, z)`` shared by all bins,
        or a list/tuple of such callables with one entry per bin
    z: redshift(s) at which the kernel is evaluated
    ell: multipole entering the ell-dependent prefactor

    Raises:
    -------
    NotImplementedError: if any of ``pzs`` is a ``rds.delta_nz``
    """
    if any(isinstance(pz, rds.delta_nz) for pz in pzs):
        raise NotImplementedError(
            "NLA kernel not properly implemented for delta redshift distributions"
        )
    # Stack the dndz of all redshift bins
    dndz = np.stack([pz(z) for pz in pzs], axis=0)
    # IA bias: either one bias per bin (list/tuple) or one shared bias
    if isinstance(bias, (list, tuple)):
        b = np.stack([bias_i(cosmo, z) for bias_i in bias], axis=0)
    else:
        b = bias(cosmo, z)
    # Radial NLA kernel: same form as the clustering kernel
    radial_kernel = dndz * b * bkgrd.H(cosmo, z2a(z))
    # Apply the common A_IA normalization to the kernel
    # Joachimi et al. (2011), arXiv: 1008.3491, Eq. 6.
    radial_kernel *= (
        -(5e-14 * const.rhocrit) * cosmo.Omega_m / bkgrd.growth_factor(cosmo, z2a(z))
    )
    # Constant factor
    constant_factor = 1.0
    # Ell-dependent factor, identical to the weak lensing one
    ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2
    return constant_factor * ell_factor * radial_kernel
@register_pytree_node_class
class WeakLensing(container):
    """
    Class representing a weak lensing probe, with a bunch of bins

    Parameters:
    -----------
    redshift_bins: list of redshift distributions
    ia_bias: (optional) if provided, IA will be added with the NLA model,
        either a single bias object or a list of same size as redshift_bins
    multiplicative_bias: (optional) adds a (1+m) multiplicative bias, either a
        single value or a list/tuple of same length as redshift bins

    Configuration:
    --------------
    sigma_e: intrinsic galaxy ellipticity, either a single value or a
        list/tuple with one value per redshift bin
    """

    def __init__(
        self,
        redshift_bins,
        ia_bias=None,
        multiplicative_bias=0.0,
        sigma_e=0.26,
        **kwargs
    ):
        # ia_bias only becomes a traced child of the container when it is
        # provided. The stored params are (redshift_bins, multiplicative_bias
        # [, ia_bias]). Note that this order differs from the argument order.
        if ia_bias is None:
            ia_enabled = False
            args = (redshift_bins, multiplicative_bias)
        else:
            ia_enabled = True
            args = (redshift_bins, multiplicative_bias, ia_bias)
        # An explicitly passed ``ia_enabled`` takes precedence over the
        # inferred value
        kwargs.setdefault("ia_enabled", ia_enabled)
        super().__init__(*args, sigma_e=sigma_e, **kwargs)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """
        Rebuilds a WeakLensing probe from its flattened pytree form.

        ``__init__`` takes ``(redshift_bins, ia_bias, multiplicative_bias)``,
        whereas the params are stored as ``(redshift_bins,
        multiplicative_bias[, ia_bias])``. An unflatten that re-calls
        ``cls(*children, **aux_data)`` would therefore bind the multiplicative
        bias to ``ia_bias`` and silently reset it to 0.0. Instead, the
        children and the configuration are handed unchanged to the container
        constructor.
        """
        # NOTE(review): assumes the container flattens to
        # (self.params, self.config) - confirm in jax_cosmo.jax_utils.
        probe = cls.__new__(cls)
        super(WeakLensing, probe).__init__(*children, **aux_data)
        return probe

    @property
    def n_tracers(self):
        """
        Returns the number of tracers for this probe, i.e. redshift bins
        """
        pzs = self.params[0]
        return len(pzs)

    @property
    def zmax(self):
        """
        Returns the maximum redshift probed by this probe
        """
        pzs = self.params[0]
        return max(pz.zmax for pz in pzs)

    def kernel(self, cosmo, z, ell):
        """
        Compute the radial kernel for all nz bins in this probe.

        Returns:
        --------
        radial_kernel: shape (nbins, nz)
        """
        z = np.atleast_1d(z)
        # Extract parameters
        pzs, m = self.params[:2]
        kernel = weak_lensing_kernel(cosmo, pzs, z, ell)
        # If IA is enabled, add the NLA intrinsic alignment kernel
        if self.config["ia_enabled"]:
            bias = self.params[2]
            kernel += nla_kernel(cosmo, pzs, bias, z, ell)
        # Apply the multiplicative shear bias. Per-bin values become a
        # (nbins, 1) column so that they broadcast over redshift.
        if isinstance(m, (list, tuple)):
            m = np.expand_dims(np.stack(m, axis=0), 1)
        kernel *= 1.0 + m
        return kernel

    def noise(self):
        """
        Returns the shape noise power sigma_e^2 / n_gal for all redshift bins

        return: shape [nbins]
        """
        pzs = self.params[0]
        # Number of galaxies per steradian in each bin
        ngals = np.array([pz.gals_per_steradian for pz in pzs])
        sigma_e = self.config["sigma_e"]
        if isinstance(sigma_e, (list, tuple)):
            sigma_e = np.array(sigma_e)
        return sigma_e**2 / ngals
@register_pytree_node_class
class NumberCounts(container):
    """Class representing a galaxy clustering probe, with a bunch of bins

    Parameters:
    -----------
    redshift_bins: list of redshift distributions
    bias: galaxy bias, either a single bias object shared by all bins or a
        list of same length as redshift_bins

    Configuration:
    --------------
    has_rsd: boolean flag stored in the configuration. It is not read by the
        methods of this class (presumably it toggles redshift-space
        distortions elsewhere - confirm).
    """

    def __init__(self, redshift_bins, bias, has_rsd=False, **kwargs):
        super().__init__(redshift_bins, bias, has_rsd=has_rsd, **kwargs)

    @property
    def zmax(self):
        """
        Returns the maximum redshift probed by this probe
        """
        pzs = self.params[0]
        return max(pz.zmax for pz in pzs)

    @property
    def n_tracers(self):
        """Returns the number of tracers for this probe, i.e. redshift bins"""
        pzs = self.params[0]
        return len(pzs)

    def kernel(self, cosmo, z, ell):
        """Compute the radial kernel for all nz bins in this probe.

        Returns:
        --------
        radial_kernel: shape (nbins, nz)
        """
        z = np.atleast_1d(z)
        # Extract parameters
        pzs, bias = self.params
        # Retrieve the density kernel
        return density_kernel(cosmo, pzs, bias, z, ell)

    def noise(self):
        """Returns the shot noise power 1 / n_gal for all redshift bins

        return: shape [nbins]
        """
        pzs = self.params[0]
        ngals = np.array([pz.gals_per_steradian for pz in pzs])
        return 1.0 / ngals