This commit is contained in:
Wassim Kabalan 2024-12-08 22:45:09 +01:00
parent 7823fdaf98
commit af29c4005d
7 changed files with 68 additions and 63 deletions

View file

@ -1,9 +1,7 @@
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.distributed import (fft3d, ifft3d,
normal_field)
from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
@ -27,7 +25,8 @@ def pm_forces(positions,
mesh_shape = delta.shape
if paint_absolute_pos:
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
device=sharding),
pos,
halo_size=halo_size,
sharding=sharding)
@ -72,7 +71,8 @@ def lpt(cosmo,
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -172,10 +172,11 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo, mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(a, state, args):
"""
@ -199,6 +200,7 @@ def make_diffrax_ode(cosmo, mesh_shape,
return nbody_ode
def pgd_correction(pos, mesh_shape, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,