By default use absoulute painting with

This commit is contained in:
Wassim Kabalan 2024-12-05 18:23:18 +01:00
parent c1b276d224
commit e0c118a540

View file

@ -17,7 +17,7 @@ def pm_forces(positions,
mesh_shape=None, mesh_shape=None,
delta=None, delta=None,
r_split=0, r_split=0,
paint_particles=False, paint_absolute_pos=True,
halo_size=0, halo_size=0,
sharding=None): sharding=None):
""" """
@ -28,7 +28,7 @@ def pm_forces(positions,
"If mesh_shape is not provided, delta should be provided" "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape mesh_shape = delta.shape
if paint_particles: if paint_absolute_pos:
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding), paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
x, x,
halo_size=halo_size, halo_size=halo_size,
@ -72,14 +72,14 @@ def lpt(cosmo,
Computes first and second order LPT displacement and momentum, Computes first and second order LPT displacement and momentum,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
""" """
paint_particles = particles is not None paint_absolute_pos = particles is not None
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a)) E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = fft3d(initial_conditions) delta_k = fft3d(initial_conditions)
initial_force = pm_forces(particles, initial_force = pm_forces(particles,
delta=delta_k, delta=delta_k,
paint_particles=paint_particles, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
@ -111,7 +111,7 @@ def lpt(cosmo,
delta_k2 = fft3d(delta2) delta_k2 = fft3d(delta2)
init_force2 = pm_forces(particles, init_force2 = pm_forces(particles,
delta=delta_k2, delta=delta_k2,
paint_particles=paint_particles, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
@ -144,18 +144,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
return field return field
def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None): def make_ode_fn(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(state, a, cosmo): def nbody_ode(state, a, cosmo):
""" """
state is a tuple (position, velocities) state is a tuple (position, velocities)
""" """
pos, vel = state pos, vel = state
paint_particles = particles is not None
forces = pm_forces(pos, forces = pm_forces(pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,
paint_particles=paint_particles, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m sharding=sharding) * 1.5 * cosmo.Omega_m
@ -165,6 +167,8 @@ def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None):
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
#dpos = dpos if not paint_absolute_pos else dpos + pos
return dpos, dvel return dpos, dvel
return nbody_ode return nbody_ode