diff --git a/jaxpm/pm.py b/jaxpm/pm.py index f8059c7..262b916 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -17,7 +17,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, - paint_particles=False, + paint_absolute_pos=True, halo_size=0, sharding=None): """ @@ -28,7 +28,7 @@ def pm_forces(positions, "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape - if paint_particles: + if paint_absolute_pos: paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding), x, halo_size=halo_size, @@ -72,14 +72,14 @@ def lpt(cosmo, Computes first and second order LPT displacement and momentum, e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) """ - paint_particles = particles is not None + paint_absolute_pos = particles is not None a = jnp.atleast_1d(a) E = jnp.sqrt(jc.background.Esqr(cosmo, a)) delta_k = fft3d(initial_conditions) initial_force = pm_forces(particles, delta=delta_k, - paint_particles=paint_particles, + paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, sharding=sharding) dx = growth_factor(cosmo, a) * initial_force @@ -111,7 +111,7 @@ def lpt(cosmo, delta_k2 = fft3d(delta2) init_force2 = pm_forces(particles, delta=delta_k2, - paint_particles=paint_particles, + paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, sharding=sharding) # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second @@ -144,18 +144,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None): return field -def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None): +def make_ode_fn(mesh_shape, + paint_absolute_pos=True, + halo_size=0, + sharding=None): def nbody_ode(state, a, cosmo): """ state is a tuple (position, velocities) """ pos, vel = state - paint_particles = particles is not None forces = pm_forces(pos, mesh_shape=mesh_shape, - paint_particles=paint_particles, + paint_absolute_pos=paint_absolute_pos, halo_size=halo_size, sharding=sharding) * 1.5 * cosmo.Omega_m @@ -165,6 +167,8 @@ def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None): # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces + #dpos = dpos if not paint_absolute_pos else dpos + pos + return dpos, dvel return nbody_ode