mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-08 04:40:53 +00:00
By default use absoulute painting with
This commit is contained in:
parent
c1b276d224
commit
e0c118a540
1 changed files with 12 additions and 8 deletions
20
jaxpm/pm.py
20
jaxpm/pm.py
|
@ -17,7 +17,7 @@ def pm_forces(positions,
|
||||||
mesh_shape=None,
|
mesh_shape=None,
|
||||||
delta=None,
|
delta=None,
|
||||||
r_split=0,
|
r_split=0,
|
||||||
paint_particles=False,
|
paint_absolute_pos=True,
|
||||||
halo_size=0,
|
halo_size=0,
|
||||||
sharding=None):
|
sharding=None):
|
||||||
"""
|
"""
|
||||||
|
@ -28,7 +28,7 @@ def pm_forces(positions,
|
||||||
"If mesh_shape is not provided, delta should be provided"
|
"If mesh_shape is not provided, delta should be provided"
|
||||||
mesh_shape = delta.shape
|
mesh_shape = delta.shape
|
||||||
|
|
||||||
if paint_particles:
|
if paint_absolute_pos:
|
||||||
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
|
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
|
||||||
x,
|
x,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
|
@ -72,14 +72,14 @@ def lpt(cosmo,
|
||||||
Computes first and second order LPT displacement and momentum,
|
Computes first and second order LPT displacement and momentum,
|
||||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||||
"""
|
"""
|
||||||
paint_particles = particles is not None
|
paint_absolute_pos = particles is not None
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
delta_k = fft3d(initial_conditions)
|
delta_k = fft3d(initial_conditions)
|
||||||
initial_force = pm_forces(particles,
|
initial_force = pm_forces(particles,
|
||||||
delta=delta_k,
|
delta=delta_k,
|
||||||
paint_particles=paint_particles,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
|
@ -111,7 +111,7 @@ def lpt(cosmo,
|
||||||
delta_k2 = fft3d(delta2)
|
delta_k2 = fft3d(delta2)
|
||||||
init_force2 = pm_forces(particles,
|
init_force2 = pm_forces(particles,
|
||||||
delta=delta_k2,
|
delta=delta_k2,
|
||||||
paint_particles=paint_particles,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||||
|
@ -144,18 +144,20 @@ def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||||
return field
|
return field
|
||||||
|
|
||||||
|
|
||||||
def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None):
|
def make_ode_fn(mesh_shape,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
def nbody_ode(state, a, cosmo):
|
def nbody_ode(state, a, cosmo):
|
||||||
"""
|
"""
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
paint_particles = particles is not None
|
|
||||||
|
|
||||||
forces = pm_forces(pos,
|
forces = pm_forces(pos,
|
||||||
mesh_shape=mesh_shape,
|
mesh_shape=mesh_shape,
|
||||||
paint_particles=paint_particles,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
|
@ -165,6 +167,8 @@ def make_ode_fn(mesh_shape, particles=None, halo_size=0, sharding=None):
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
|
#dpos = dpos if not paint_absolute_pos else dpos + pos
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
Loading…
Add table
Reference in a new issue