From 9b21eed3b5c8d7a3ccaa267dc546c01f239747cd Mon Sep 17 00:00:00 2001 From: Hugo Simonfroy Date: Wed, 31 Jul 2024 00:46:53 +0200 Subject: [PATCH 1/2] 2lpt, get_ode, invlaplace, docstrings --- jaxpm/kernels.py | 143 +++++++++++++++++++++++++--------------------- jaxpm/painting.py | 62 ++++++++++++++------ jaxpm/pm.py | 115 +++++++++++++++++++++++++++---------- 3 files changed, 206 insertions(+), 114 deletions(-) diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 8447f8a..3bcb9ee 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -3,8 +3,9 @@ import numpy as np def fftk(shape, symmetric=True, finite=False, dtype=np.float32): - """ Return k_vector given a shape (nc, nc, nc) and box_size - """ + """ + Return wave-vectors for a given shape + """ k = [] for d in range(len(shape)): kd = np.fft.fftfreq(shape[d]) @@ -22,18 +23,20 @@ def fftk(shape, symmetric=True, finite=False, dtype=np.float32): def gradient_kernel(kvec, direction, order=1): """ - Computes the gradient kernel in the requested direction - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - direction: int - Index of the direction in which to take the gradient - Returns: - -------- - wts: array - Complex kernel - """ + Computes the gradient kernel in the requested direction + + Parameters + ----------- + kvec: list + List of wave-vectors in Fourier space + direction: int + Index of the direction in which to take the gradient + + Returns + -------- + wts: array + Complex kernel values + """ if order == 0: wts = 1j * kvec[direction] wts = jnp.squeeze(wts) @@ -47,41 +50,43 @@ def gradient_kernel(kvec, direction, order=1): return wts -def laplace_kernel(kvec): +def invlaplace_kernel(kvec): + """ + Compute the inverse Laplace kernel + + Parameters + ----------- + kvec: list + List of wave-vectors + + Returns + -------- + wts: array + Complex kernel values """ - Compute the Laplace kernel from a given K vector - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - Returns: - -------- - wts: array - Complex kernel - """ kk = sum(ki**2 for ki in kvec) - mask = (kk == 0).nonzero() - kk[mask] = 1 - wts = 1. / kk - imask = (~(kk == 0)).astype(int) - wts *= imask - return wts + kk_nozeros = jnp.where(kk==0, 1, kk) + return - jnp.where(kk==0, 0, 1 / kk_nozeros) def longrange_kernel(kvec, r_split): """ - Computes a long range kernel - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - r_split: float + Computes a long range kernel + + Parameters + ----------- + kvec: list + List of wave-vectors + r_split: float + Splitting radius + + Returns + -------- + wts: array + Complex kernel values + TODO: @modichirag add documentation - Returns: - -------- - wts: array - kernel - """ + """ if r_split != 0: kk = sum(ki**2 for ki in kvec) return np.exp(-kk * r_split**2) @@ -91,15 +96,21 @@ def longrange_kernel(kvec, r_split): def cic_compensation(kvec): """ - Computes cic compensation kernel. - Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 - Itself based on equation 18 (with p=2) of - `Jing et al 2005 `_ - Args: - kvec: array of k values in Fourier space - Returns: - v: array of kernel - """ + Computes cic compensation kernel. + Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 + Itself based on equation 18 (with p=2) of + [Jing et al 2005](https://arxiv.org/abs/astro-ph/0409240) + + Parameters: + ----------- + kvec: list + List of wave-vectors + + Returns: + -------- + wts: array + Complex kernel values + """ kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] wts = (kwts[0] * kwts[1] * kwts[2])**(-2) return wts @@ -107,20 +118,22 @@ def cic_compensation(kvec): def PGD_kernel(kvec, kl, ks): """ - Computes the PGD kernel - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - kl: float - initial long range scale parameter - ks: float - initial dhort range scale parameter - Returns: - -------- - v: array - kernel - """ + Computes the PGD kernel + + Parameters: + ----------- + kvec: list + List of wave-vectors + kl: float + Initial long range scale parameter + ks: float + Initial dhort range scale parameter + + Returns: + -------- + v: array + Complex kernel values + """ kk = sum(ki**2 for ki in kvec) kl2 = kl**2 ks4 = ks**4 diff --git a/jaxpm/painting.py b/jaxpm/painting.py index fb5dbd5..7b46949 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -6,10 +6,18 @@ from jaxpm.kernels import cic_compensation, fftk def cic_paint(mesh, positions, weight=None): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ + """ + Paint positions onto mesh + + Parameters: + ----------- + mesh: [nx, ny, nz] + positions: [npart, 3] + + Returns: + -------- + mesh: [nx, ny, nz] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -35,10 +43,18 @@ def cic_paint(mesh, positions, weight=None): def cic_read(mesh, positions): - """ Paints positions onto mesh - mesh: [nx, ny, nz] - positions: [npart, 3] - """ + """ + Read mesh at positions + + Parameters: + ----------- + mesh: [nx, ny, nz] + positions: [npart, 3] + + Returns: + -------- + values: [npart] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -56,11 +72,19 @@ def cic_read(mesh, positions): def cic_paint_2d(mesh, positions, weight): - """ Paints positions onto a 2d mesh - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - """ + """ + Paints positions onto 2d mesh + + Parameters: + ----------- + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + + Returns: + -------- + mesh: [nx, ny] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -86,12 +110,12 @@ def cic_paint_2d(mesh, positions, weight): def compensate_cic(field): """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ nc = field.shape kvec = fftk(nc) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 9b14a87..4aedef5 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,48 +1,80 @@ import jax import jax.numpy as jnp import jax_cosmo as jc +from jax_cosmo import Cosmology -from jaxpm.growth import dGfa, growth_factor, growth_rate -from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, - longrange_kernel) +from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a +from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel from jaxpm.painting import cic_paint, cic_read -def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): + +def pm_forces(positions, mesh_shape, delta=None, r_split=0): """ Computes gravitational forces on particles using a PM scheme """ - if mesh_shape is None: - mesh_shape = delta.shape - kvec = fftk(mesh_shape) - if delta is None: delta_k = jnp.fft.rfftn(cic_paint(jnp.zeros(mesh_shape), positions)) - else: + elif jnp.isrealobj(delta): delta_k = jnp.fft.rfftn(delta) + else: + delta_k = delta # Computes gravitational potential - pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, - r_split=r_split) + kvec = fftk(mesh_shape) + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) # Computes gravitational forces - return jnp.stack([ - cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), positions) - for i in range(3) - ], - axis=-1) + return jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i) * pot_k), positions) + for i in range(3)], axis=-1) -def lpt(cosmo, initial_conditions, positions, a): +def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1): """ - Computes first order LPT displacement + Computes first and second order LPT displacement and momentum, + e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) """ - initial_force = pm_forces(positions, delta=initial_conditions) a = jnp.atleast_1d(a) - dx = growth_factor(cosmo, a) * initial_force - p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, - a)) * dx - f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, - a) * initial_force + E = jnp.sqrt(jc.background.Esqr(cosmo, a)) + delta_k = jnp.fft.rfftn(init_mesh) # TODO: pass the modes directly to save one or two fft? + mesh_shape = init_mesh.shape + + init_force = pm_forces(positions, mesh_shape, delta=delta_k) + dx = growth_factor(cosmo, a) * init_force + p = a**2 * growth_rate(cosmo, a) * E * dx + f = a**2 * E * dGfa(cosmo, a) * init_force + + if order == 2: + kvec = fftk(mesh_shape) + pot_k = delta_k * invlaplace_kernel(kvec) + + delta2 = 0 + shear_acc = 0 + # for i, ki in enumerate(kvec): + for i in range(3): + # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)... + # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k) + nabla_i_nabla_i = gradient_kernel(kvec, i)**2 + shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k) + delta2 += shear_ii * shear_acc + shear_acc += shear_ii + + # for kj in kvec[i+1:]: + for j in range(i+1, 3): + # Substract squared strict-up-triangle terms + # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2 + nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j) + delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2 + + init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2)) + # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second + dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2 + p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2 + f2 = a**2 * E * dGf2a(cosmo, a) * init_force2 + + dx += dx2 + p += p2 + f += f2 + return dx, p, f @@ -82,10 +114,33 @@ def make_ode_fn(mesh_shape): return nbody_ode +def get_ode_fn(cosmo:Cosmology, mesh_shape): + + def nbody_ode(a, state, args): + """ + State is an array [position, velocities] + + Compatible with [Diffrax API](https://docs.kidger.site/diffrax/) + """ + pos, vel = state + forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m + + # Computes the update of position (drift) + dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel + + # Computes the update of velocity (kick) + dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + + return jnp.stack([dpos, dvel]) + + return nbody_ode + def pgd_correction(pos, mesh_shape, params): """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 + improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, + based on https://arxiv.org/abs/1804.00671 + args: pos: particle positions [npart, 3] params: [alpha, kl, ks] pgd parameters @@ -96,9 +151,9 @@ def pgd_correction(pos, mesh_shape, params): delta_k = jnp.fft.rfftn(delta) PGD_range=PGD_kernel(kvec, kl, ks) - pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range + pot_k_pgd=(delta_k * invlaplace_kernel(kvec))*PGD_range - forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) + forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k_pgd), pos) for i in range(3)],axis=-1) dpos_pgd = forces_pgd*alpha @@ -107,7 +162,7 @@ def pgd_correction(pos, mesh_shape, params): def make_neural_ode_fn(model, mesh_shape): - def neural_nbody_ode(state, a, cosmo, params): + def neural_nbody_ode(state, a, cosmo:Cosmology, params): """ state is a tuple (position, velocities) """ @@ -119,14 +174,14 @@ def make_neural_ode_fn(model, mesh_shape): delta_k = jnp.fft.rfftn(delta) # Computes gravitational potential - pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) + pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0) # Apply a correction filter kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec)) pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a))) # Computes gravitational forces - forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k), pos) + forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos) for i in range(3)],axis=-1) forces = forces * 1.5 * cosmo.Omega_m From 30060e82ea7e8249f772531d5281dabaa6fca872 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 5 Aug 2024 19:37:33 +0200 Subject: [PATCH 2/2] roll back painting --- jaxpm/painting.py | 62 +++++++++++++++-------------------------------- 1 file changed, 19 insertions(+), 43 deletions(-) diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 7b46949..fb5dbd5 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -6,18 +6,10 @@ from jaxpm.kernels import cic_compensation, fftk def cic_paint(mesh, positions, weight=None): - """ - Paint positions onto mesh - - Parameters: - ----------- - mesh: [nx, ny, nz] - positions: [npart, 3] - - Returns: - -------- - mesh: [nx, ny, nz] - """ + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -43,18 +35,10 @@ def cic_paint(mesh, positions, weight=None): def cic_read(mesh, positions): - """ - Read mesh at positions - - Parameters: - ----------- - mesh: [nx, ny, nz] - positions: [npart, 3] - - Returns: - -------- - values: [npart] - """ + """ Paints positions onto mesh + mesh: [nx, ny, nz] + positions: [npart, 3] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], @@ -72,19 +56,11 @@ def cic_read(mesh, positions): def cic_paint_2d(mesh, positions, weight): - """ - Paints positions onto 2d mesh - - Parameters: - ----------- - mesh: [nx, ny] - positions: [npart, 2] - weight: [npart] - - Returns: - -------- - mesh: [nx, ny] - """ + """ Paints positions onto a 2d mesh + mesh: [nx, ny] + positions: [npart, 2] + weight: [npart] + """ positions = jnp.expand_dims(positions, 1) floor = jnp.floor(positions) connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]]) @@ -110,12 +86,12 @@ def cic_paint_2d(mesh, positions, weight): def compensate_cic(field): """ - Compensate for CiC painting - Args: - field: input 3D cic-painted field - Returns: - compensated_field - """ + Compensate for CiC painting + Args: + field: input 3D cic-painted field + Returns: + compensated_field + """ nc = field.shape kvec = fftk(nc)