mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
remove deprecated stuff
This commit is contained in:
parent
8e8e8964be
commit
0f833f0cb4
8 changed files with 96 additions and 831 deletions
|
@ -50,6 +50,8 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
|
|||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
||||
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
|
||||
|
||||
|
@ -63,6 +65,8 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
|
|||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
|
||||
|
||||
print(f"shape of grid_mesh: {grid_mesh.shape}")
|
||||
return grid_mesh
|
||||
|
||||
|
||||
|
@ -97,19 +101,20 @@ def cic_read_impl(mesh, displacement):
|
|||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3))
|
||||
def cic_read(mesh, displacement, halo_size=0, sharding=None):
|
||||
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||
mesh = slice_pad(mesh, halo_size, sharding=sharding)
|
||||
mesh = halo_exchange(mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
halo_periods=(True, True))
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
displacement = autoshmap(cic_read_impl,
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, spec),
|
||||
out_specs=spec)(mesh, displacement)
|
||||
out_specs=spec)(grid_mesh, positions)
|
||||
print(f"shape of displacement: {displacement.shape}")
|
||||
|
||||
return displacement
|
||||
|
||||
|
|
45
jaxpm/pm.py
45
jaxpm/pm.py
|
@ -18,18 +18,29 @@ def pm_forces(positions,
|
|||
mesh_shape=None,
|
||||
delta=None,
|
||||
r_split=0,
|
||||
paint_particles=False,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
print(f"pm_forces particles are {positions}")
|
||||
original_shape = positions.shape
|
||||
if mesh_shape is None:
|
||||
assert (delta is not None),\
|
||||
"If mesh_shape is not provided, delta should be provided"
|
||||
mesh_shape = delta.shape
|
||||
|
||||
positions = positions.reshape((*mesh_shape, 3))
|
||||
if paint_particles:
|
||||
paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape))
|
||||
read_fn = partial(cic_read, positions=positions)
|
||||
else:
|
||||
paint_fn = cic_paint_dx
|
||||
read_fn = cic_read_dx
|
||||
|
||||
if delta is None:
|
||||
field = cic_paint_dx(positions, halo_size=halo_size, sharding=sharding)
|
||||
field = paint_fn(positions, halo_size=halo_size, sharding=sharding)
|
||||
delta_k = fft3d(field)
|
||||
elif jnp.isrealobj(delta):
|
||||
delta_k = fft3d(delta)
|
||||
|
@ -42,33 +53,44 @@ def pm_forces(positions,
|
|||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = jnp.stack([
|
||||
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
||||
halo_size=halo_size,
|
||||
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
|
||||
|
||||
return forces
|
||||
|
||||
|
||||
def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
||||
def lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles=None,
|
||||
a=0.1,
|
||||
halo_size=0,
|
||||
sharding=None,
|
||||
order=1):
|
||||
"""
|
||||
Computes first and second order LPT displacement and momentum,
|
||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||
"""
|
||||
print(f"particles are {particles}")
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
spec = sharding.spec if sharding is not None else P()
|
||||
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
|
||||
displacement = autoshmap(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
|
||||
paint_particles = True
|
||||
original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable
|
||||
if particles is None:
|
||||
paint_particles = False
|
||||
particles = autoshmap(
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(),
|
||||
out_specs=spec)() # yapf: disable
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
delta_k = fft3d(initial_conditions)
|
||||
initial_force = pm_forces(displacement,
|
||||
initial_force = pm_forces(particles,
|
||||
delta=delta_k,
|
||||
paint_particles=paint_particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
|
@ -100,6 +122,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
|||
delta_k2 = fft3d(delta2)
|
||||
init_force2 = pm_forces(displacement,
|
||||
delta=delta_k2,
|
||||
paint_particles=paint_particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
|
||||
|
@ -111,7 +134,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
|
|||
p += p2
|
||||
f += f2
|
||||
|
||||
return dx, p, f
|
||||
return dx.reshape(original_shape), p, f
|
||||
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
|
||||
|
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -22,16 +22,16 @@ from jaxpm.kernels import interpolate_power_spectrum
|
|||
from jaxpm.painting import cic_paint_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
|
||||
size = 64
|
||||
size = 256
|
||||
mesh_shape = [size] * 3
|
||||
box_size = [float(size)] * 3
|
||||
snapshots = jnp.linspace(0.1, 1., 4)
|
||||
halo_size = 4
|
||||
halo_size = 32
|
||||
pdims = (1, 1)
|
||||
mesh = None
|
||||
sharding = None
|
||||
if jax.device_count() > 1:
|
||||
pdims = (2, 4)
|
||||
pdims = (2, 2)
|
||||
devices = mesh_utils.create_device_mesh(pdims)
|
||||
mesh = Mesh(devices.T, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
|
|
|
@ -27,7 +27,7 @@ def initialize_distributed():
|
|||
on_cluster = False
|
||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
os.environ[
|
||||
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
|
||||
import jax
|
||||
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -7,5 +7,5 @@ setup(
|
|||
author='JaxPM developers',
|
||||
description='A dead simple FastPM implementation in JAX',
|
||||
packages=find_packages(),
|
||||
install_requires=['jax', 'jax_cosmo'],
|
||||
install_requires=['jax', 'jax_cosmo','jaxdecomp'],
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue