Wassim Kabalan 2025-02-28 13:47:43 +01:00
parent 580387ce1c
commit 9f494da317
3 changed files with 62 additions and 75 deletions

View file

@@ -30,8 +30,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.):
     if jnp.isscalar(weight):
         kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
     else:
-        kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
-                              kernel)
+        kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
 
     neighboor_coords = jnp.mod(
         neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@@ -158,7 +157,10 @@ def cic_paint_2d(mesh, positions, weight):
     return mesh
 
 
-def _cic_paint_dx_impl(displacements, weight=1. , halo_size=0 , chunk_size=2**24):
+def _cic_paint_dx_impl(displacements,
+                       weight=1.,
+                       halo_size=0,
+                       chunk_size=2**24):
 
     halo_x, _ = halo_size[0]
     halo_y, _ = halo_size[1]
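Note: the reformatted _cic_paint_dx_impl signature keeps halo_size=0 as its default, but the body unpacks halo_size[0] and halo_size[1] as pairs, so the caller is expected to pass per-axis (before, after) halos. A hypothetical illustration of that layout (the values and the normalization step are assumptions, not shown in this diff):

# Hypothetical per-axis halo layout as unpacked by _cic_paint_dx_impl;
# the real values come from the calling code, which is not part of this diff.
halo_size = ((16, 16), (16, 16))  # (before, after) halo cells along x and y

halo_x, _ = halo_size[0]  # 16 halo cells on the low side of x
halo_y, _ = halo_size[1]  # 16 halo cells on the low side of y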
@@ -203,7 +205,7 @@ def cic_paint_dx(displacements,
                                  chunk_size=chunk_size),
                          gpu_mesh=gpu_mesh,
                          in_specs=(spec, weight_spec),
-                         out_specs=spec)(displacements , weight)
+                         out_specs=spec)(displacements, weight)
 
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,

View file

@@ -121,8 +121,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
 
     # Initial displacement
     dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
-    ode_fn = ODETerm(
-        make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
+    ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
 
     solver = Dopri5()
     controller = PIDController(rtol=1e-9,
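For reference, the diffrax pieces this test wires together (ODETerm, Dopri5, PIDController, SaveAt, diffeqsolve) compose as in this minimal, self-contained sketch on a toy ODE dy/dt = -y; it is not the jaxpm N-body term, just the same solver plumbing:

import jax.numpy as jnp
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve

# Right-hand side of dy/dt = -y; diffrax vector fields take (t, y, args).
term = ODETerm(lambda t, y, args: -y)
solver = Dopri5()
controller = PIDController(rtol=1e-9, atol=1e-9)
saveat = SaveAt(t1=True)  # keep only the state at t1

solution = diffeqsolve(term,
                       solver,
                       t0=0.0,
                       t1=1.0,
                       dt0=0.05,
                       y0=jnp.ones(4),
                       stepsize_controller=controller,
                       saveat=saveat)
print(solution.ys)  # shape (1, 4), values close to exp(-1)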

View file

@@ -2,8 +2,11 @@ from conftest import initialize_distributed
 
 initialize_distributed()  # ignore : E402
 
+from functools import partial  # noqa : E402
+
 import jax  # noqa : E402
 import jax.numpy as jnp  # noqa : E402
+import jax_cosmo as jc  # noqa : E402
 import pytest  # noqa : E402
 from diffrax import SaveAt  # noqa : E402
 from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@@ -12,30 +15,30 @@ from jax import lax  # noqa : E402
 from jax.experimental.multihost_utils import process_allgather  # noqa : E402
 from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P  # noqa : E402
+from jaxdecomp import get_fft_output_sharding
 from jaxpm.distributed import uniform_particles  # noqa : E402
+from jaxpm.distributed import fft3d, ifft3d
 from jaxpm.painting import cic_paint, cic_paint_dx  # noqa : E402
-from jaxpm.pm import lpt, make_diffrax_ode , pm_forces  # noqa : E402
-from functools import partial  # noqa : E402
-import jax_cosmo as jc  # noqa : E402
-from jaxpm.distributed import fft3d , ifft3d
-from jaxdecomp import get_fft_output_sharding
+from jaxpm.pm import lpt, make_diffrax_ode, pm_forces  # noqa : E402
 
 _TOLERANCE = 1e-1  # 🙃🙃
 
-pdims = [(1, 8) , (8 , 1) , (4 , 2), (2 , 4)]
+pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
 
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("order", [1, 2])
 @pytest.mark.parametrize("pdims", pdims)
 @pytest.mark.parametrize("absolute_painting", [True, False])
-def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,pdims,
-                       absolute_painting):
+def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
+                       pdims, absolute_painting):
 
     if absolute_painting:
         pytest.skip("Absolute painting is not recommended in distributed mode")
 
     painting_str = "absolute" if absolute_painting else "relative"
-    print("="*50)
+    print("=" * 50)
     print(f"Running with {painting_str} painting and pdims {pdims} ...")
 
     mesh_shape, box_shape = simulation_config
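The distributed tests in this file all start from the same sharding setup: build a 2D process mesh from one of the parametrized pdims layouts, wrap it in a NamedSharding, and constrain the field to it. A condensed sketch of that setup (assumes 8 visible devices so the (4, 2) layout fits; the 64^3 field size is arbitrary):

import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

pdims = (4, 2)                               # one of the parametrized layouts
mesh = jax.make_mesh(pdims, ('x', 'y'))      # 4 x 2 process mesh over 8 devices
sharding = NamedSharding(mesh, P('x', 'y'))  # shard the first two array axes

field = jnp.zeros((64, 64, 64))
field = lax.with_sharding_constraint(field, sharding)
print(field.sharding)  # expect a NamedSharding with spec PartitionSpec('x', 'y')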
@@ -170,46 +173,40 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,pdims
     assert mse < _TOLERANCE
 
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("order", [1, 2])
 @pytest.mark.parametrize("pdims", pdims)
 def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
-                              order, nbody_from_lpt1, nbody_from_lpt2 , pdims):
+                              order, nbody_from_lpt1, nbody_from_lpt2, pdims):
 
     mesh_shape, box_shape = simulation_config
     # SINGLE DEVICE RUN
     cosmo._workspace = {}
 
     mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2
 
     initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                       sharding)
 
     print(f"sharded initial conditions {initial_conditions.sharding}")
 
     cosmo._workspace = {}
 
     @jax.jit
     def forward_model(initial_conditions, cosmo):
 
         dx, p, _ = lpt(cosmo,
                        initial_conditions,
                        a=0.1,
                        order=order,
                        halo_size=halo_size,
                        sharding=sharding)
 
         ode_fn = ODETerm(
             make_diffrax_ode(mesh_shape,
                              paint_absolute_pos=False,
                              halo_size=halo_size,
                              sharding=sharding))
 
         y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
 
         solver = Dopri5()
@@ -219,7 +216,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
                                     icoeff=1,
                                     dcoeff=0)
 
         saveat = SaveAt(t1=True)
 
         solutions = diffeqsolve(ode_fn,
                                 solver,
@@ -231,15 +227,12 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
                                 stepsize_controller=controller,
                                 saveat=saveat)
 
         multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
                                                 halo_size=halo_size,
                                                 sharding=sharding)
 
         return multi_device_final_field
 
     @jax.jit
     def model(initial_conditions, cosmo):
 
         final_field = forward_model(initial_conditions, cosmo)
@@ -251,38 +244,33 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
     shifted_initial_conditions = initial_conditions + jax.random.normal(
         jax.random.key(42), initial_conditions.shape) * 5
 
     good_grads = jax.grad(model)(initial_conditions, cosmo)
     off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
 
-    assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
+    assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
+                                                ndim=3)
+    assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
+                                               ndim=3)
 
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("pdims", pdims)
-def test_fwd_rev_gradients(cosmo,pdims):
+def test_fwd_rev_gradients(cosmo, pdims):
 
     mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
     # SINGLE DEVICE RUN
     cosmo._workspace = {}
 
     mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2
 
     initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
     initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                       sharding)
 
     print(f"sharded initial conditions {initial_conditions.sharding}")
 
     cosmo._workspace = {}
 
     @partial(jax.jit, static_argnums=(2, 3, 4))
     def compute_forces(initial_conditions,
                        cosmo,
@@ -290,16 +278,15 @@ def test_fwd_rev_gradients(cosmo,pdims):
                        halo_size=0,
                        sharding=None):
 
         paint_absolute_pos = False
         particles = jnp.zeros_like(initial_conditions,
                                    shape=(*initial_conditions.shape, 3))
 
         a = jnp.atleast_1d(a)
         E = jnp.sqrt(jc.background.Esqr(cosmo, a))
-        initial_conditions = jax.lax.with_sharding_constraint(initial_conditions,sharding)
+        initial_conditions = jax.lax.with_sharding_constraint(
+            initial_conditions, sharding)
         delta_k = fft3d(initial_conditions)
         out_sharding = get_fft_output_sharding(sharding)
         delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
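The compute_forces helper above pins both the real-space input and the spectral output of the distributed FFT to explicit shardings. A condensed sketch of that pattern, assuming jaxpm and jaxdecomp are installed and sharding is the NamedSharding built earlier in the test:

import jax
from jaxdecomp import get_fft_output_sharding
from jaxpm.distributed import fft3d


def sharded_fft(field, sharding):
    # Pin the real-space field to the process mesh before the distributed FFT.
    field = jax.lax.with_sharding_constraint(field, sharding)
    field_k = fft3d(field)
    # Ask jaxdecomp which sharding the FFT output carries for this input layout
    # and constrain the spectral field to it, as compute_forces does above.
    out_sharding = get_fft_output_sharding(sharding)
    return jax.lax.with_sharding_constraint(field_k, out_sharding)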
@@ -310,10 +297,8 @@ def test_fwd_rev_gradients(cosmo,pdims):
                                     halo_size=halo_size,
                                     sharding=sharding)
         return initial_force[..., 0]
 
     forces = compute_forces(initial_conditions,
                             cosmo,
                             halo_size=halo_size,
@@ -327,41 +312,44 @@ def test_fwd_rev_gradients(cosmo,pdims):
                                      halo_size=halo_size,
                                      sharding=sharding)
 
     print(f"Forces sharding is {forces.sharding}")
     print(f"Backward gradient sharding is {back_gradient.sharding}")
     print(f"Forward gradient sharding is {fwd_gradient.sharding}")
-    assert forces.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
+    assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
+                                            ndim=3)
+    assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
+        initial_conditions.sharding, ndim=3)
+    assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
+                                                  ndim=3)
 
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("pdims", pdims)
-def test_vmap(cosmo,pdims):
+def test_vmap(cosmo, pdims):
 
     mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
     # SINGLE DEVICE RUN
     cosmo._workspace = {}
 
     mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2
 
-    single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
-    initial_conditions = lax.with_sharding_constraint(single_dev_initial_conditions,
-                                                      sharding)
-    single_ics = jnp.stack([single_dev_initial_conditions, single_dev_initial_conditions , single_dev_initial_conditions])
-    sharded_ics = jnp.stack([initial_conditions, initial_conditions , initial_conditions])
+    single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
+                                                      mesh_shape)
+    initial_conditions = lax.with_sharding_constraint(
+        single_dev_initial_conditions, sharding)
+
+    single_ics = jnp.stack([
+        single_dev_initial_conditions, single_dev_initial_conditions,
+        single_dev_initial_conditions
+    ])
+    sharded_ics = jnp.stack(
+        [initial_conditions, initial_conditions, initial_conditions])
 
     print(f"unsharded initial conditions batch {single_ics.sharding}")
     print(f"sharded initial conditions batch {sharded_ics.sharding}")
 
     cosmo._workspace = {}
 
     @partial(jax.jit, static_argnums=(2, 3, 4))
     def compute_forces(initial_conditions,
                        cosmo,
@@ -369,16 +357,15 @@ def test_vmap(cosmo,pdims):
                        halo_size=0,
                        sharding=None):
 
         paint_absolute_pos = False
         particles = jnp.zeros_like(initial_conditions,
                                    shape=(*initial_conditions.shape, 3))
 
         a = jnp.atleast_1d(a)
         E = jnp.sqrt(jc.background.Esqr(cosmo, a))
-        initial_conditions = jax.lax.with_sharding_constraint(initial_conditions,sharding)
+        initial_conditions = jax.lax.with_sharding_constraint(
+            initial_conditions, sharding)
         delta_k = fft3d(initial_conditions)
         out_sharding = get_fft_output_sharding(sharding)
         delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
@@ -389,14 +376,13 @@ def test_vmap(cosmo,pdims):
                                     halo_size=halo_size,
                                     sharding=sharding)
         return initial_force[..., 0]
 
     def fn(ic):
         return compute_forces(ic,
                               cosmo,
                               halo_size=halo_size,
                               sharding=sharding)
 
     v_compute_forces = jax.vmap(fn)
@@ -409,8 +395,8 @@ def test_vmap(cosmo,pdims):
     assert single_dev_forces.ndim == 4
     assert sharded_forces.ndim == 4
 
     print(f"Sharded forces {sharded_forces.sharding}")
 
-    assert sharded_forces[0].sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
+    assert sharded_forces[0].sharding.is_equivalent_to(
+        initial_conditions.sharding, ndim=3)
     assert sharded_forces.sharding.spec[0] == None