remove deprecated stuff

This commit is contained in:
Francois Lanusse 2024-10-24 16:36:41 -04:00 committed by Wassim KABALAN
parent 8e8e8964be
commit 0f833f0cb4
8 changed files with 96 additions and 831 deletions

View file

@ -50,6 +50,8 @@ def cic_paint_impl(grid_mesh, displacement, weight=None):
@partial(jax.jit, static_argnums=(2, 3, 4)) @partial(jax.jit, static_argnums=(2, 3, 4))
def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None): def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding) halo_size, halo_extents = get_halo_size(halo_size, sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding)
@ -63,6 +65,8 @@ def cic_paint(grid_mesh, positions, halo_size=0, weight=None, sharding=None):
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True)) halo_periods=(True, True))
grid_mesh = slice_unpad(grid_mesh, halo_size, sharding) grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
print(f"shape of grid_mesh: {grid_mesh.shape}")
return grid_mesh return grid_mesh
@ -97,19 +101,20 @@ def cic_read_impl(mesh, displacement):
@partial(jax.jit, static_argnums=(2, 3)) @partial(jax.jit, static_argnums=(2, 3))
def cic_read(mesh, displacement, halo_size=0, sharding=None): def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
mesh = slice_pad(mesh, halo_size, sharding=sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
mesh = halo_exchange(mesh, grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents, halo_extents=halo_extents,
halo_periods=(True, True)) halo_periods=(True, True))
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P() spec = sharding.spec if sharding is not None else P()
displacement = autoshmap(cic_read_impl, displacement = autoshmap(cic_read_impl,
gpu_mesh=gpu_mesh, gpu_mesh=gpu_mesh,
in_specs=(spec, spec), in_specs=(spec, spec),
out_specs=spec)(mesh, displacement) out_specs=spec)(grid_mesh, positions)
print(f"shape of displacement: {displacement.shape}")
return displacement return displacement

View file

@ -18,18 +18,29 @@ def pm_forces(positions,
mesh_shape=None, mesh_shape=None,
delta=None, delta=None,
r_split=0, r_split=0,
paint_particles=False,
halo_size=0, halo_size=0,
sharding=None): sharding=None):
""" """
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
""" """
print(f"pm_forces particles are {positions}")
original_shape = positions.shape
if mesh_shape is None: if mesh_shape is None:
assert (delta is not None),\ assert (delta is not None),\
"If mesh_shape is not provided, delta should be provided" "If mesh_shape is not provided, delta should be provided"
mesh_shape = delta.shape mesh_shape = delta.shape
positions = positions.reshape((*mesh_shape, 3))
if paint_particles:
paint_fn = partial(cic_paint, grid_mesh=jnp.zeros(mesh_shape))
read_fn = partial(cic_read, positions=positions)
else:
paint_fn = cic_paint_dx
read_fn = cic_read_dx
if delta is None: if delta is None:
field = cic_paint_dx(positions, halo_size=halo_size, sharding=sharding) field = paint_fn(positions, halo_size=halo_size, sharding=sharding)
delta_k = fft3d(field) delta_k = fft3d(field)
elif jnp.isrealobj(delta): elif jnp.isrealobj(delta):
delta_k = fft3d(delta) delta_k = fft3d(delta)
@ -42,33 +53,44 @@ def pm_forces(positions,
kvec, r_split=r_split) kvec, r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
forces = jnp.stack([ forces = jnp.stack([
cic_read_dx(ifft3d(-gradient_kernel(kvec, i) * pot_k), read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) for i in range(3)], axis=-1) # yapf: disable sharding=sharding) for i in range(3)], axis=-1) # yapf: disable
return forces return forces
def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1): def lpt(cosmo,
initial_conditions,
particles=None,
a=0.1,
halo_size=0,
sharding=None,
order=1):
""" """
Computes first and second order LPT displacement and momentum, Computes first and second order LPT displacement and momentum,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
""" """
print(f"particles are {particles}")
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
spec = sharding.spec if sharding is not None else P() spec = sharding.spec if sharding is not None else P()
local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), 3) # yapf: disable
displacement = autoshmap( paint_particles = True
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), original_shape = particles.shape if particles is not None else (*initial_conditions.shape, 3) # yapf: disable
gpu_mesh=gpu_mesh, if particles is None:
in_specs=(), paint_particles = False
out_specs=spec)() # yapf: disable particles = autoshmap(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
gpu_mesh=gpu_mesh,
in_specs=(),
out_specs=spec)() # yapf: disable
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a)) E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = fft3d(initial_conditions) delta_k = fft3d(initial_conditions)
initial_force = pm_forces(displacement, initial_force = pm_forces(particles,
delta=delta_k, delta=delta_k,
paint_particles=paint_particles,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
@ -100,6 +122,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
delta_k2 = fft3d(delta2) delta_k2 = fft3d(delta2)
init_force2 = pm_forces(displacement, init_force2 = pm_forces(displacement,
delta=delta_k2, delta=delta_k2,
paint_particles=paint_particles,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
@ -111,7 +134,7 @@ def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None, order=1):
p += p2 p += p2
f += f2 f += f2
return dx, p, f return dx.reshape(original_shape), p, f
def linear_field(mesh_shape, box_size, pk, seed, sharding=None): def linear_field(mesh_shape, box_size, pk, seed, sharding=None):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -22,16 +22,16 @@ from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt, make_ode_fn from jaxpm.pm import linear_field, lpt, make_ode_fn
size = 64 size = 256
mesh_shape = [size] * 3 mesh_shape = [size] * 3
box_size = [float(size)] * 3 box_size = [float(size)] * 3
snapshots = jnp.linspace(0.1, 1., 4) snapshots = jnp.linspace(0.1, 1., 4)
halo_size = 4 halo_size = 32
pdims = (1, 1) pdims = (1, 1)
mesh = None mesh = None
sharding = None sharding = None
if jax.device_count() > 1: if jax.device_count() > 1:
pdims = (2, 4) pdims = (2, 2)
devices = mesh_utils.create_device_mesh(pdims) devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices.T, axis_names=('x', 'y')) mesh = Mesh(devices.T, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))

View file

@ -27,7 +27,7 @@ def initialize_distributed():
on_cluster = False on_cluster = False
os.environ["JAX_PLATFORM_NAME"] = "cpu" os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ[ os.environ[
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" "XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
import jax import jax

View file

@ -7,5 +7,5 @@ setup(
author='JaxPM developers', author='JaxPM developers',
description='A dead simple FastPM implementation in JAX', description='A dead simple FastPM implementation in JAX',
packages=find_packages(), packages=find_packages(),
install_requires=['jax', 'jax_cosmo'], install_requires=['jax', 'jax_cosmo','jaxdecomp'],
) )