mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
remove deprecated stuff
This commit is contained in:
parent
8e8e8964be
commit
0f833f0cb4
8 changed files with 96 additions and 831 deletions
|
@ -50,6 +50,8 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
|
|||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
|
||||
|
||||
|
@ -63,6 +65,8 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
|||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||
|
||||
print(f"shape of grid_mesh: {grid_mesh.shape}")
|
||||
return grid_mesh
|
||||
|
||||
|
||||
|
@ -97,19 +101,20 @@ def cic_read_impl(mesh, displacement):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read(mesh, displacement, halo_size=0, sharding=None):
|
||||
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
mesh = slice_pad(mesh, halo_size, sharding=sharding)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec),
|
||||
out_specs=spec)(mesh, displacement)
|
||||
out_specs=spec)(grid_mesh, positions)
|
||||
print(f"shape of displacement: {displacement.shape}")
|
||||
|
||||
return displacement
|
||||
|
||||
|
|
45
jaxpm/pm.py
45
jaxpm/pm.py
|
@ -18,18 +18,29 @@ def pm_forces(positions,
|
|||
mesh_shape=None,
|
||||
delta=None,
|
||||
r_split=0,
|
||||
paint_particles=False,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
print(f"pm_forces particles are {positions}")
|
||||
original_shape = positions.shape
|
||||
if mesh_shape is None:
|
||||
assert (delta is not None),\
|
||||
"If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
|
||||
positions = positions.reshape((*mesh_shape, 3))
|
||||
if paint_particles:
|
||||
paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape))
|
||||
read_fn = partial(cic_read, positions=positions)
|
||||
else:
|
||||
paint_fn = cic_paint_dx
|
||||
read_fn = cic_read_dx
|
||||
|
||||
if delta is None:
|
||||
field = cic_paint_dx(positions, halo_size=halo_size, sharding=sharding)
|
||||
field = paint_fn(positions, halo_size=halo_size, sharding=sharding)
|
||||
delta_k = fft3d(field)
|
||||
elif jnp.isrealobj(delta):
|
||||
delta_k = fft3d(delta)
|
||||
|
@ -42,33 +53,44 @@ def pm_forces(positions,
|
|||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
|
||||
|
||||
return forces
|
||||
|
||||
|
||||
def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
||||
def lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles=None,
|
||||
a=0.1,
|
||||
halo_size=0,
|
||||
sharding=None,
|
||||
order=1):
|
||||
"""
|
||||
Computes first and second order LPT displacement and momentum,
|
||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||
"""
|
||||
print(f"particles are {particles}")
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
|
||||
displacement = autoshmap(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
|
||||
paint_particles = True
|
||||
original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable
|
||||
if particles is None:
|
||||
paint_particles = False
|
||||
particles = autoshmap(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
delta_k = fft3d(initial_conditions)
|
||||
initial_force = pm_forces(displacement,
|
||||
initial_force = pm_forces(particles,
|
||||
delta=delta_k,
|
||||
paint_particles=paint_particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
|
@ -100,6 +122,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
|||
delta_k2 = fft3d(delta2)
|
||||
init_force2 = pm_forces(displacement,
|
||||
delta=delta_k2,
|
||||
paint_particles=paint_particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||
|
@ -111,7 +134,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
|||
p += p2
|
||||
f += f2
|
||||
|
||||
return dx, p, f
|
||||
return dx.reshape(original_shape), p, f
|
||||
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue