From d4d0a03c79f6d4b7885fa956cd4e7f23eb2450e7 Mon Sep 17 00:00:00 2001 From: EiffL Date: Fri, 27 Aug 2021 00:52:33 +0200 Subject: [PATCH 1/4] Adds test scrript --- dev/test_script.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 dev/test_script.py diff --git a/dev/test_script.py b/dev/test_script.py new file mode 100644 index 0000000..a9566c2 --- /dev/null +++ b/dev/test_script.py @@ -0,0 +1,63 @@ +# Start this script with: +# mpirun -np 4 python test_script.py +import os +os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4' +import matplotlib.pylab as plt +import jax +import numpy as np +import jax.numpy as jnp +import jax.lax as lax +from jax.experimental.maps import mesh, xmap +from jax.experimental.pjit import PartitionSpec, pjit +import tensorflow_probability as tfp; tfp = tfp.substrates.jax +tfd = tfp.distributions + +def cic_paint(mesh, positions): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, 2)) + mesh = lax.scatter_add(mesh, + neighboor_coords.reshape([-1,8,3]).astype('int32'), + kernel.reshape([-1,8]), + dnums) + return mesh + +# And let's draw some points from some 3D distribution +dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.) +pos = dist.sample(1e4, seed=jax.random.PRNGKey(0)) + +f = pjit(lambda x: cic_paint(x, pos), + in_axis_resources=PartitionSpec('x', 'y', 'z'), + out_axis_resources=None) + +devices = np.array(jax.devices()).reshape((2, 2, 1)) + +# Let's import the mesh +m = jnp.zeros([32, 32, 32]) + +with mesh(devices, ('x', 'y', 'z')): + # Shard the mesh, I'm not sure this is absolutely necessary + m = pjit(lambda x: x, + in_axis_resources=None, + out_axis_resources=PartitionSpec('x', 'y', 'z'))(m) + + # Apply the sharded CiC function + res = f(m) + +plt.imshow(res.sum(axis=2)) +plt.show() \ No newline at end of file From 350733966331b4df988e873c43132871e995201e Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 13 Feb 2022 21:36:03 +0100 Subject: [PATCH 2/4] Adds a trivial jaxpm implementation --- jaxpm/__init__.py | 0 jaxpm/kernels.py | 85 +++++++++++++++++++++++++++++++++++++++++++++++ jaxpm/painting.py | 51 ++++++++++++++++++++++++++++ setup.py | 11 ++++++ 4 files changed, 147 insertions(+) create mode 100644 jaxpm/__init__.py create mode 100644 jaxpm/kernels.py create mode 100644 jaxpm/painting.py create mode 100644 setup.py diff --git a/jaxpm/__init__.py b/jaxpm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py new file mode 100644 index 0000000..73a8c93 --- /dev/null +++ b/jaxpm/kernels.py @@ -0,0 +1,85 @@ +import numpy as np +import jax.numpy as jnp + +def fftk(shape, symmetric=True, finite=False, dtype=np.float32): + """ Return k_vector given a shape (nc, nc, nc) and box_size + """ + k = [] + for d in range(len(shape)): + kd = np.fft.fftfreq(shape[d]) + kd *= 2 * np.pi + kdshape = np.ones(len(shape), dtype='int') + if symmetric and d == len(shape) - 1: + kd = kd[:shape[d] // 2 + 1] + kdshape[d] = len(kd) + kd = kd.reshape(kdshape) + + k.append(kd.astype(dtype)) + del kd, kdshape + return k + +def gradient_kernel(kvec, direction, order=1): + """ + Computes the gradient kernel in the requested direction + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + direction: int + Index of the direction in which to take the gradient + Returns: + -------- + wts: array + Complex kernel + """ + if order == 0: + wts = 1j * kvec[direction] + wts = jnp.squeeze(wts) + wts[len(wts) // 2] = 0 + wts = wts.reshape(kvec[direction].shape) + return wts + else: + w = kvec[direction] + a = 1 / 6.0 * (8 * jnp.sin(w) - jnp.sin(2 * w)) + wts = a * 1j + return wts + +def laplace_kernel(kvec): + """ + Compute the Laplace kernel from a given K vector + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + Returns: + -------- + wts: array + Complex kernel + """ + kk = sum(ki**2 for ki in kvec) + mask = (kk == 0).nonzero() + kk[mask] = 1 + wts = 1. / kk + imask = (~(kk == 0)).astype(int) + wts *= imask + return wts + +def longrange_kernel(kvec, r_split): + """ + Computes a long range kernel + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + r_split: float + TODO: @modichirag add documentation + Returns: + -------- + wts: array + kernel + """ + if r_split != 0: + kk = sum(ki**2 for ki in kvec) + return np.exp(-kk * r_split**2) + else: + return 1. diff --git a/jaxpm/painting.py b/jaxpm/painting.py new file mode 100644 index 0000000..6eb1925 --- /dev/null +++ b/jaxpm/painting.py @@ -0,0 +1,51 @@ +import jax +import jax.numpy as jnp +import jax.lax as lax + +def cic_paint(mesh, positions): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords = jnp.mod(neighboor_coords.reshape([-1,8,3]).astype('int32'), jnp.array(mesh.shape)) + + dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0, 1, 2), + scatter_dims_to_operand_dims=(0, 1, 2)) + mesh = lax.scatter_add(mesh, + neighboor_coords, + kernel.reshape([-1,8]), + dnums) + return mesh + +def cic_read(mesh, positions): + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ + positions = jnp.expand_dims(positions, 1) + floor = jnp.floor(positions) + connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], + [0., 0, 1], [1., 1, 0], [1., 0, 1], + [0., 1, 1], [1., 1, 1]]]) + + neighboor_coords = floor + connection + kernel = 1. - jnp.abs(positions - neighboor_coords) + kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] + + neighboor_coords = jnp.mod(neighboor_coords.astype('int32'), jnp.array(mesh.shape)) + + return (mesh[neighboor_coords[...,0], + neighboor_coords[...,1], + neighboor_coords[...,3]]*kernel).sum(axis=-1) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..44be5a1 --- /dev/null +++ b/setup.py @@ -0,0 +1,11 @@ +from setuptools import setup, find_packages + +setup( + name='JaxPM', + version='0.0.1', + url='https://github.com/DifferentiableUniverseInitiative/JaxPM', + author='JaxPM developers', + description='A dead simple FastPM implementation in JAX', + packages=find_packages(), + install_requires=['jax', 'jax_cosmo'], +) \ No newline at end of file From faec622152296134881f638fc70276bf3dcee0b5 Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 14 Feb 2022 00:37:35 +0100 Subject: [PATCH 3/4] adds growth functions from Chirag --- jaxpm/growth.py | 608 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 608 insertions(+) create mode 100644 jaxpm/growth.py diff --git a/jaxpm/growth.py b/jaxpm/growth.py new file mode 100644 index 0000000..0be4718 --- /dev/null +++ b/jaxpm/growth.py @@ -0,0 +1,608 @@ +import jax.numpy as np + +from jax_cosmo.scipy.interpolate import interp +from jax_cosmo.scipy.ode import odeint +from jax_cosmo.background import * + +def E(cosmo, a): + r"""Scale factor dependent factor E(a) in the Hubble + parameter. + Parameters + ---------- + a : array_like + Scale factor + Returns + ------- + E : ndarray, or float if input scalar + Square of the scaling of the Hubble constant as a function of + scale factor + Notes + ----- + The Hubble parameter at scale factor `a` is given by + :math:`H^2(a) = E^2(a) H_o^2` where :math:`E^2` is obtained through + Friedman's Equation (see :cite:`2005:Percival`) : + .. math:: + E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} a^{f(a)} + where :math:`f(a)` is the Dark Energy evolution parameter computed + by :py:meth:`.f_de`. + """ + return np.power(Esqr(cosmo, a), 0.5) + + +def df_de(cosmo, a, epsilon=1e-5): + r"""Derivative of the evolution parameter for the Dark Energy density + f(a) with respect to the scale factor. + Parameters + ---------- + cosmo: Cosmology + Cosmological parameters structure + a : array_like + Scale factor + epsilon: float value + Small number to make sure we are not dividing by 0 and avoid a singularity + Returns + ------- + df(a)/da : ndarray, or float if input scalar + Derivative of the evolution parameter for the Dark Energy density + with respect to the scale factor. + Notes + ----- + The expression for :math:`\frac{df(a)}{da}` is: + .. math:: + \frac{df}{da}(a) = =\frac{3w_a \left( \ln(a-\epsilon)- + \frac{a-1}{a-\epsilon}\right)}{\ln^2(a-\epsilon)} + """ + return ( + 3 + * cosmo.wa + * (np.log(a - epsilon) - (a - 1) / (a - epsilon)) + / np.power(np.log(a - epsilon), 2) + ) + + +def dEa(cosmo, a): + r"""Derivative of the scale factor dependent factor E(a) in the Hubble + parameter. + Parameters + ---------- + a : array_like + Scale factor + Returns + ------- + dE(a)/da : ndarray, or float if input scalar + Derivative of the scale factor dependent factor in the Hubble + parameter with respect to the scale factor. + Notes + ----- + The expression for :math:`\frac{dE}{da}` is: + .. math:: + \frac{dE(a)}{da}=\frac{-3a^{-4}\Omega_{0m} + -2a^{-3}\Omega_{0k} + +f'_{de}\Omega_{0de}a^{f_{de}(a)}}{2E(a)} + Notes + ----- + The Hubble parameter at scale factor `a` is given by + :math:`H^2(a) = E^2(a) H_o^2` where :math:`E^2` is obtained through + Friedman's Equation (see :cite:`2005:Percival`) : + .. math:: + E^2(a) = \Omega_m a^{-3} + \Omega_k a^{-2} + \Omega_{de} a^{f(a)} + where :math:`f(a)` is the Dark Energy evolution parameter computed + by :py:meth:`.f_de`. + """ + return ( + 0.5 + * ( + -3 * cosmo.Omega_m * np.power(a, -4) + - 2 * cosmo.Omega_k * np.power(a, -3) + + df_de(cosmo, a) * cosmo.Omega_de * np.power(a, f_de(cosmo, a)) + ) + / np.power(Esqr(cosmo, a), 0.5) + ) + + +def growth_factor(cosmo, a): + """Compute linear growth factor D(a) at a given scale factor, + normalized such that D(a=1) = 1. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + D: ndarray, or float if input scalar + Growth factor computed at requested scale factor + + Notes + ----- + The growth computation will depend on the cosmology parametrization, for + instance if the $\gamma$ parameter is defined, the growth will be computed + assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for + growth will be solved. + """ + if cosmo._flags["gamma_growth"]: + return _growth_factor_gamma(cosmo, a) + else: + return _growth_factor_ODE(cosmo, a) + + +def growth_factor_second(cosmo, a): + """Compute second order growth factor D2(a) at a given scale factor, + normalized such that D(a=1) = 1. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + D2: ndarray, or float if input scalar + Growth factor computed at requested scale factor + + Notes + ----- + The growth computation will depend on the cosmology parametrization, + as for the linear growth. Currently the second order growth + factor is not implemented with $\gamma$ parameter. + """ + if cosmo._flags["gamma_growth"]: + raise NotImplementedError( + "Gamma growth rate is not implemented for second order growth!" + ) + return None + else: + return _growth_factor_second_ODE(cosmo, a) + + +def growth_rate(cosmo, a): + """Compute growth rate dD/dlna at a given scale factor. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + f: ndarray, or float if input scalar + Growth rate computed at requested scale factor + + Notes + ----- + The growth computation will depend on the cosmology parametrization, for + instance if the $\gamma$ parameter is defined, the growth will be computed + assuming the $f = \Omega^\gamma$ growth rate, otherwise the usual ODE for + growth will be solved. + + The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by: + + .. math:: + + f_{\gamma}(a) = \Omega_m^{\gamma} (a) + + with :math: `\gamma` in LCDM, given approximately by: + .. math:: + + \gamma = 0.55 + + see :cite:`2019:Euclid Preparation VII, eqn.32` + """ + if cosmo._flags["gamma_growth"]: + return _growth_rate_gamma(cosmo, a) + else: + return _growth_rate_ODE(cosmo, a) + + +def growth_rate_second(cosmo, a): + """Compute second order growth rate dD2/dlna at a given scale factor. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + f2: ndarray, or float if input scalar + Second order growth rate computed at requested scale factor + + Notes + ----- + The growth computation will depend on the cosmology parametrization, + as for the linear growth rate. Currently the second order growth + rate is not implemented with $\gamma$ parameter. + """ + if cosmo._flags["gamma_growth"]: + raise NotImplementedError( + "Gamma growth factor is not implemented for second order growth!" + ) + return None + else: + return _growth_rate_second_ODE(cosmo, a) + + +def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=128, eps=1e-4): + """Compute linear growth factor D(a) at a given scale factor, + normalised such that D(a=1) = 1. + + Parameters + ---------- + a: array_like + Scale factor + + amin: float + Mininum scale factor, default 1e-3 + + Returns + ------- + D: ndarray, or float if input scalar + Growth factor computed at requested scale factor + """ + # Check if growth has already been computed + if not "background.growth_factor" in cosmo._workspace.keys(): + # Compute tabulated array + atab = np.logspace(log10_amin, 0.0, steps) + + def D_derivs(y, x): + q = ( + 2.0 + - 0.5 + * ( + Omega_m_a(cosmo, x) + + (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x) + ) + ) / x + r = 1.5 * Omega_m_a(cosmo, x) / x / x + + g1, g2 = y[0] + f1, f2 = y[1] + dy1da = [f1, -q * f1 + r * g1] + dy2da = [f2, -q * f2 + r * g2 - r * g1 ** 2] + return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]]) + + y0 = np.array([[atab[0], -3.0 / 7 * atab[0] ** 2], [1.0, -6.0 / 7 * atab[0]]]) + y = odeint(D_derivs, y0, atab) + + # compute second order derivatives growth + dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab) + dyda2 = np.transpose(dyda2, (2, 0, 1)) + + # Normalize results + y1 = y[:, 0, 0] + gtab = y1 / y1[-1] + y2 = y[:, 0, 1] + g2tab = y2 / y2[-1] + # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da + ftab = y[:, 1, 0] / y1[-1] * atab / gtab + f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab + # Similarly for second order derivatives + # Note: these factors are not accessible as parent functions yet + # since it is unclear what to refer to them with. + htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab + h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab + + cache = { + "a": atab, + "g": gtab, + "f": ftab, + "h": htab, + "g2": g2tab, + "f2": f2tab, + "h2": h2tab, + } + cosmo._workspace["background.growth_factor"] = cache + else: + cache = cosmo._workspace["background.growth_factor"] + return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + + +def _growth_rate_ODE(cosmo, a): + """Compute growth rate dD/dlna at a given scale factor by solving the linear + growth ODE. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + f: ndarray, or float if input scalar + Growth rate computed at requested scale factor + """ + # Check if growth has already been computed, if not, compute it + if not "background.growth_factor" in cosmo._workspace.keys(): + _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + cache = cosmo._workspace["background.growth_factor"] + return interp(a, cache["a"], cache["f"]) + + +def _growth_factor_second_ODE(cosmo, a): + """Compute second order growth factor D2(a) at a given scale factor, + normalised such that D(a=1) = 1. + + Parameters + ---------- + a: array_like + Scale factor + + amin: float + Mininum scale factor, default 1e-3 + + Returns + ------- + D2: ndarray, or float if input scalar + Second order growth factor computed at requested scale factor + """ + # Check if growth has already been computed, if not, compute it + if not "background.growth_factor" in cosmo._workspace.keys(): + _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + cache = cosmo._workspace["background.growth_factor"] + return interp(a, cache["a"], cache["g2"]) + + +def _growth_rate_ODE(cosmo, a): + """Compute growth rate dD/dlna at a given scale factor by solving the linear + growth ODE. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + f: ndarray, or float if input scalar + Second order growth rate computed at requested scale factor + """ + # Check if growth has already been computed, if not, compute it + if not "background.growth_factor" in cosmo._workspace.keys(): + _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + cache = cosmo._workspace["background.growth_factor"] + return interp(a, cache["a"], cache["f"]) + + +def _growth_rate_second_ODE(cosmo, a): + """Compute second order growth rate dD2/dlna at a given scale factor by solving the linear + growth ODE. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a: array_like + Scale factor + + Returns + ------- + f2: ndarray, or float if input scalar + Second order growth rate computed at requested scale factor + """ + # Check if growth has already been computed, if not, compute it + if not "background.growth_factor" in cosmo._workspace.keys(): + _growth_factor_ODE(cosmo, np.atleast_1d(1.0)) + cache = cosmo._workspace["background.growth_factor"] + return interp(a, cache["a"], cache["f2"]) + + +def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128): + r"""Computes growth factor by integrating the growth rate provided by the + \gamma parametrization. Normalized such that D( a=1) =1 + + Parameters + ---------- + a: array_like + Scale factor + + amin: float + Mininum scale factor, default 1e-3 + + Returns + ------- + D: ndarray, or float if input scalar + Growth factor computed at requested scale factor + + """ + # Check if growth has already been computed, if not, compute it + if not "background.growth_factor" in cosmo._workspace.keys(): + # Compute tabulated array + atab = np.logspace(log10_amin, 0.0, steps) + + def integrand(y, loga): + xa = np.exp(loga) + return _growth_rate_gamma(cosmo, xa) + + gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab))) + gtab = gtab / gtab[-1] # Normalize to a=1. + cache = {"a": atab, "g": gtab} + cosmo._workspace["background.growth_factor"] = cache + else: + cache = cosmo._workspace["background.growth_factor"] + return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) + + +def _growth_rate_gamma(cosmo, a): + r"""Growth rate approximation at scale factor `a`. + + Parameters + ---------- + cosmo: `Cosmology` + Cosmology object + + a : array_like + Scale factor + + Returns + ------- + f_gamma : ndarray, or float if input scalar + Growth rate approximation at the requested scale factor + + Notes + ----- + The LCDM approximation to the growth rate :math:`f_{\gamma}(a)` is given by: + + .. math:: + + f_{\gamma}(a) = \Omega_m^{\gamma} (a) + + with :math: `\gamma` in LCDM, given approximately by: + .. math:: + + \gamma = 0.55 + + see :cite:`2019:Euclid Preparation VII, eqn.32` + """ + return Omega_m_a(cosmo, a) ** cosmo.gamma + + + +def Gf(cosmo, a): + r""" + FastPM growth factor function + + Parameters + ---------- + cosmo: dict + Cosmology dictionary. + + a : array_like + Scale factor. + + Returns + ------- + Scalar float Tensor : FastPM growth factor function. + + Notes + ----- + + The expression for :math:`Gf(a)` is: + + .. math:: + Gf(a)=D'_{1norm}*a**3*E(a) + """ + f1 = growth_rate(cosmo, a) + g1 = growth_factor(cosmo, a) + D1f = f1*g1/ a + return D1f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) + + +def Gf2(cosmo, a): + r""" FastPM second order growth factor function + + Parameters + ---------- + cosmo: dict + Cosmology dictionary. + + a : array_like + Scale factor. + + Returns + ------- + Scalar float Tensor : FastPM second order growth factor function. + + Notes + ----- + + The expression for :math:`Gf_2(a)` is: + + .. math:: + Gf_2(a)=D'_{2norm}*a**3*E(a) + """ + f2 = growth_rate_second(cosmo, a) + g2 = growth_factor_second(cosmo, a) + D2f = f2*g2/ a + return D2f * np.power(a, 3) * np.power(Esqr(cosmo, a), 0.5) + + +def dGfa(cosmo, a): + r""" Derivative of Gf against a + + Parameters + ---------- + cosmo: dict + Cosmology dictionary. + + a : array_like + Scale factor. + + Returns + ------- + Scalar float Tensor : the derivative of Gf against a. + + Notes + ----- + + The expression for :math:`gf(a)` is: + + .. math:: + gf(a)=\frac{dGF}{da}= D^{''}_1 * a ** 3 *E(a) +D'_{1norm}*a ** 3 * E'(a) + + 3 * a ** 2 * E(a)*D'_{1norm} + + """ + f1 = growth_rate(cosmo, a) + g1 = growth_factor(cosmo, a) + D1f = f1*g1/ a + cache = cosmo._workspace['background.growth_factor'] + f1p = cache['h'] / cache['a'] * cache['g'] + f1p = interp(np.log(a), np.log(cache['a']), f1p) + Ea = E(cosmo, a) + return (f1p * a**3 * Ea + D1f * a**3 * dEa(cosmo, a) + + 3 * a**2 * Ea * D1f) + + +def dGf2a(cosmo, a): + r""" Derivative of Gf2 against a + + Parameters + ---------- + cosmo: dict + Cosmology dictionary. + + a : array_like + Scale factor. + + Returns + ------- + Scalar float Tensor : the derivative of Gf2 against a. + + Notes + ----- + + The expression for :math:`gf2(a)` is: + + .. math:: + gf_2(a)=\frac{dGF_2}{da}= D^{''}_2 * a ** 3 *E(a) +D'_{2norm}*a ** 3 * E'(a) + + 3 * a ** 2 * E(a)*D'_{2norm} + + """ + f2 = growth_rate_second(cosmo, a) + g2 = growth_factor_second(cosmo, a) + D2f = f2*g2/ a + cache = cosmo._workspace['background.growth_factor'] + f2p = cache['h2'] / cache['a'] * cache['g2'] + f2p = interp(np.log(a), np.log(cache['a']), f2p) + E = E(cosmo, a) + return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + + 3 * a**2 * E * D2f) \ No newline at end of file From 8543246f62ada68fdbaf702d70d148d40484eb46 Mon Sep 17 00:00:00 2001 From: EiffL Date: Mon, 14 Feb 2022 01:59:12 +0100 Subject: [PATCH 4/4] Adds demo and notebooks --- jaxpm/pm.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 jaxpm/pm.py diff --git a/jaxpm/pm.py b/jaxpm/pm.py new file mode 100644 index 0000000..f4b405e --- /dev/null +++ b/jaxpm/pm.py @@ -0,0 +1,72 @@ +import jax +import jax.numpy as jnp + +import jax_cosmo as jc + +from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel +from jaxpm.painting import cic_paint, cic_read +from jaxpm.growth import growth_factor, growth_rate, dGfa + +def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): + """ + Computes gravitational forces on particles using a PM scheme + """ + if mesh_shape is None: + mesh_shape = delta.shape + kvec = fftk(mesh_shape) + + if delta is None: + delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions)) + else: + delta_k = jnp.fft.rfftn(delta) + + # Computes gravitational potential + pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) + # Computes gravitational forces + return jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), positions) + for i in range(3)],axis=-1) + + +def lpt(cosmo, initial_conditions, positions, a): + """ + Computes first order LPT displacement + """ + initial_force = pm_forces(positions, delta=initial_conditions) + a = jnp.atleast_1d(a) + dx = growth_factor(cosmo, a) * initial_force + p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx + f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force + return dx, p, f + +def linear_field(mesh_shape, box_size, pk, seed): + """ + Generate initial conditions. + """ + kvec = fftk(mesh_shape) + kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 + pkmesh = pk(kmesh) + + field = jax.random.normal(seed, mesh_shape) + field = jnp.fft.rfftn(field) * pkmesh**0.5 + field = jnp.fft.irfftn(field) + return field + +def make_ode_fn(mesh_shape): + + def nbody_ode(state, a, cosmo): + """ + state is a tuple (position, velocities) + """ + pos, vel = state + + forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + + return dpos, dvel + + return nbody_ode