remove deprecated stuff

This commit is contained in:
Francois Lanusse 2024-10-24 16:36:41 -04:00 committed by Wassim KABALAN
parent 8e8e8964be
commit 0f833f0cb4
8 changed files with 96 additions and 831 deletions

View file

@ -50,6 +50,8 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
@partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
@ -63,6 +65,8 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
print(f"shape of grid_mesh: {grid_mesh.shape}")
return grid_mesh
@ -97,19 +101,20 @@ def cic_read_impl(mesh, displacement):
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(mesh, displacement, halo_size=0, sharding=None):
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
mesh = slice_pad(mesh, halo_size, sharding=sharding)
mesh = halo_exchange(mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,
halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
displacement = autoshmap(cic_read_impl,
gpu_mesh=gpu_mesh,
in_specs=(spec, spec),
out_specs=spec)(mesh, displacement)
out_specs=spec)(grid_mesh, positions)
print(f"shape of displacement: {displacement.shape}")
return displacement

View file

@ -18,18 +18,29 @@ def pm_forces(positions,
mesh_shape=None,
delta=None,
r_split=0,
paint_particles=False,
halo_size=0,
sharding=None):
"""
Computes gravitational forces on particles using a PM scheme
"""
print(f"pm_forces particles are {positions}")
original_shape = positions.shape
if mesh_shape is None:
assert (delta is not None),\
"If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape
positions = positions.reshape((*mesh_shape, 3))
if paint_particles:
paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape))
read_fn = partial(cic_read, positions=positions)
else:
paint_fn = cic_paint_dx
read_fn = cic_read_dx
if delta is None:
field = cic_paint_dx(positions, halo_size=halo_size, sharding=sharding)
field = paint_fn(positions, halo_size=halo_size, sharding=sharding)
delta_k = fft3d(field)
elif jnp.isrealobj(delta):
delta_k = fft3d(delta)
@ -42,33 +53,44 @@ def pm_forces(positions,
kvec, r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k),
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
halo_size=halo_size,
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
return forces
def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
def lpt(cosmo,
initial_conditions,
particles=None,
a=0.1,
halo_size=0,
sharding=None,
order=1):
"""
Computes first and second order LPT displacement and momentum,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
"""
print(f"particles are {particles}")
gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P()
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
displacement = autoshmap(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
gpu_mesh=gpu_mesh,
in_specs=(),
out_specs=spec)() # yapf: disable
paint_particles = True
original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable
if particles is None:
paint_particles = False
particles = autoshmap(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
gpu_mesh=gpu_mesh,
in_specs=(),
out_specs=spec)() # yapf: disable
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = fft3d(initial_conditions)
initial_force = pm_forces(displacement,
initial_force = pm_forces(particles,
delta=delta_k,
paint_particles=paint_particles,
halo_size=halo_size,
sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force
@ -100,6 +122,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
delta_k2 = fft3d(delta2)
init_force2 = pm_forces(displacement,
delta=delta_k2,
paint_particles=paint_particles,
halo_size=halo_size,
sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
@ -111,7 +134,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
p += p2
f += f2
return dx, p, f
return dx.reshape(original_shape), p, f
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -22,16 +22,16 @@ from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn
size = 64
size = 256
mesh_shape = [size] * 3
box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 4
halo_size = 32
pdims = (1, 1)
mesh = None
sharding = None
if jax.device_count() > 1:
pdims = (2, 4)
pdims = (2, 2)
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

View file

@ -27,7 +27,7 @@ def initialize_distributed():
on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax

View file

@ -7,5 +7,5 @@ setup(
author='JaxPM developers',
description='A dead simple FastPM implementation in JAX',
packages=find_packages(),
install_requires=['jax', 'jax_cosmo'],
install_requires=['jax', 'jax_cosmo','jaxdecomp'],
)