From ed8cf8e532cafd69ceceba6600b6860dd3cedf6b Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 18 Jul 2024 17:04:05 +0200 Subject: [PATCH] add lpt2 --- jaxpm/pm.py | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 2 deletions(-) diff --git a/jaxpm/pm.py b/jaxpm/pm.py index bdbb9eb..ef4e887 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -7,7 +7,8 @@ from jax.sharding import PartitionSpec as P from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, normal_field) -from jaxpm.growth import dGfa, growth_factor, growth_rate +from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, + growth_rate, growth_rate_second) from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, longrange_kernel) from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx @@ -39,11 +40,46 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): return forces +def lpt2_source(mesh_size, initial_conditions): + + kvec = fftk(mesh_size) + # TODO : this has already been done for LPT1, we should reuse it + delta_k = fft3d(initial_conditions) + + source = jnp.zeros_like(delta_k) + + D1 = [1, 2, 0] + D2 = [2, 0, 1] + + # laplace_kernel should be actually inv laplace_kernel + # adding a minus sign here that will be negated when computing forces + # because F = -grad(phi) + # and phi = -laplace_kernel(delta_k) + pot_k = delta_k * laplace_kernel(delta_k) + + nabla_i_nabla_i = [ + ifft3d(gradient_kernel(kvec, i)**2 * pot_k) for i in range(3) + ] + # for diagonal terms + source += nabla_i_nabla_i[D1[0]] * nabla_i_nabla_i[D2[0]] + source += nabla_i_nabla_i[D1[1]] * nabla_i_nabla_i[D2[1]] + source += nabla_i_nabla_i[D1[2]] * nabla_i_nabla_i[D2[2]] + + # off diag terms + for i in range(3): + nabla_i_nabla_j = gradient_kernel(kvec, D1[i]) * gradient_kernel( + kvec, D2[i]) + phi = ifft3d(nabla_i_nabla_j * pot_k) + source -= phi**2 + + return source + + def lpt(cosmo, initial_conditions, a, halo_size=0): """ Computes first order LPT displacement """ - local_mesh_shape = get_local_shape(initial_conditions.shape) + (3, ) + local_mesh_shape = (*get_local_shape(initial_conditions.shape), 3) displacement = autoshmap( partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), in_specs=(), @@ -62,6 +98,39 @@ def lpt(cosmo, initial_conditions, a, halo_size=0): return dx, p, f +# @Credit Hugo Simon https://github.com/hsimonfroy/montecosmo +def lpt2(cosmo, initial_conditions, dx, p, f, a, halo_size=0): + + mesh_size = initial_conditions.shape + local_mesh_shape = (*get_local_shape(initial_conditions.shape), 3) + # TODO + # Displacements have been created in the previous step + # find a way to reuse them + displacement = autoshmap( + partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), + in_specs=(), + out_specs=P('x', 'y'))() # yapf: disable + + lpt2_delta = lpt2_source(mesh_size, initial_conditions) + delta2_k = fft3d(lpt2_delta) + + lpt2_forces = pm_forces(displacement, + mesh_size, + delta_k=delta2_k, + halo_size=halo_size) + dx2 = 3 / 7 * growth_factor_second(cosmo, a) * lpt2_forces + p2 = a**2 * growth_rate_second(cosmo, a) * jnp.sqrt( + jc.background.Esqr(cosmo, a)) * dx2 + f2 = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGf2a(cosmo, + a) * lpt2_forces + + dx += dx2 + p += p2 + f += f2 + + return dx, p, f + + def linear_field(mesh_shape, box_size, pk, seed): """ Generate initial conditions.