pm now accept pytrees

This commit is contained in:
Wassim Kabalan 2025-01-18 01:13:51 +01:00
parent 204a9526ec
commit 9e203b5680

View file

@ -7,7 +7,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
import jax
def pm_forces(positions,
mesh_shape=None,
@ -51,9 +51,11 @@ def pm_forces(positions,
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
kvec, r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
forces = [
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
) for i in range(3)], axis=-1) # yapf: disable
) for i in range(3)]
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
return forces
@ -71,8 +73,8 @@ def lpt(cosmo,
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
shape=(*ic.shape, 3)) , initial_conditions)
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -172,8 +174,7 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo,
mesh_shape,
def make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
@ -183,6 +184,7 @@ def make_diffrax_ode(cosmo,
state is a tuple (position, velocities)
"""
pos, vel = state
cosmo = args
forces = pm_forces(pos,
mesh_shape=mesh_shape,
@ -196,7 +198,7 @@ def make_diffrax_ode(cosmo,
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return jnp.stack([dpos, dvel])
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
return nbody_ode