From 9f494da317bbe03cfbbb579ad2b8ec28ed000312 Mon Sep 17 00:00:00 2001
From: Wassim Kabalan
Date: Fri, 28 Feb 2025 13:47:43 +0100
Subject: [PATCH] format

---
 jaxpm/painting.py            |  10 +--
 tests/test_against_fpm.py    |   3 +-
 tests/test_distributed_pm.py | 124 ++++++++++++++++-------------------
 3 files changed, 62 insertions(+), 75 deletions(-)

diff --git a/jaxpm/painting.py b/jaxpm/painting.py
index f8797f2..78d63ef 100644
--- a/jaxpm/painting.py
+++ b/jaxpm/painting.py
@@ -30,8 +30,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.):
     if jnp.isscalar(weight):
         kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
     else:
-        kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
-                              kernel)
+        kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
 
     neighboor_coords = jnp.mod(
         neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
@@ -158,7 +157,10 @@ def cic_paint_2d(mesh, positions, weight):
     return mesh
 
 
-def _cic_paint_dx_impl(displacements, weight=1. , halo_size=0 , chunk_size=2**24):
+def _cic_paint_dx_impl(displacements,
+                       weight=1.,
+                       halo_size=0,
+                       chunk_size=2**24):
 
     halo_x, _ = halo_size[0]
     halo_y, _ = halo_size[1]
@@ -203,7 +205,7 @@ def cic_paint_dx(displacements,
                                   chunk_size=chunk_size),
                           gpu_mesh=gpu_mesh,
                           in_specs=(spec, weight_spec),
-                          out_specs=spec)(displacements , weight)
+                          out_specs=spec)(displacements, weight)
 
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,
diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py
index e02eff4..5ef5211 100644
--- a/tests/test_against_fpm.py
+++ b/tests/test_against_fpm.py
@@ -121,8 +121,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
     # Initial displacement
     dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
 
-    ode_fn = ODETerm(
-        make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
+    ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
 
     solver = Dopri5()
     controller = PIDController(rtol=1e-9,
diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py
index aa68d54..6e94f56 100644
--- a/tests/test_distributed_pm.py
+++ b/tests/test_distributed_pm.py
@@ -2,8 +2,11 @@ from conftest import initialize_distributed
 
 initialize_distributed()  # ignore : E402
 
+from functools import partial  # noqa : E402
+
 import jax  # noqa : E402
 import jax.numpy as jnp  # noqa : E402
+import jax_cosmo as jc  # noqa : E402
 import pytest  # noqa : E402
 from diffrax import SaveAt  # noqa : E402
 from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@@ -12,30 +15,30 @@ from jax import lax  # noqa : E402
 from jax.experimental.multihost_utils import process_allgather  # noqa : E402
 from jax.sharding import NamedSharding
 from jax.sharding import PartitionSpec as P  # noqa : E402
+from jaxdecomp import get_fft_output_sharding
 
 from jaxpm.distributed import uniform_particles  # noqa : E402
+from jaxpm.distributed import fft3d, ifft3d
 from jaxpm.painting import cic_paint, cic_paint_dx  # noqa : E402
-from jaxpm.pm import lpt, make_diffrax_ode , pm_forces  # noqa : E402
-from functools import partial  # noqa : E402
-import jax_cosmo as jc  # noqa : E402
-from jaxpm.distributed import fft3d , ifft3d
-from jaxdecomp import get_fft_output_sharding
+from jaxpm.pm import lpt, make_diffrax_ode, pm_forces  # noqa : E402
+
 
 _TOLERANCE = 1e-1  # 🙃🙃
 
-pdims = [(1, 8) , (8 , 1) , (4 , 2), (2 , 4)]
+pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
+
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("order", [1, 2])
 @pytest.mark.parametrize("pdims", pdims)
 @pytest.mark.parametrize("absolute_painting", [True, False])
-def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,pdims,
-                       absolute_painting):
+def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
+                       pdims, absolute_painting):
     if absolute_painting:
         pytest.skip("Absolute painting is not recommended in distributed mode")
 
     painting_str = "absolute" if absolute_painting else "relative"
-    print("="*50)
+    print("=" * 50)
     print(f"Running with {painting_str} painting and pdims {pdims} ...")
 
     mesh_shape, box_shape = simulation_config
@@ -170,46 +173,40 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,pdims
 
     assert mse < _TOLERANCE
 
-
 @pytest.mark.distributed
 @pytest.mark.parametrize("order", [1, 2])
 @pytest.mark.parametrize("pdims", pdims)
 def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
-                              order, nbody_from_lpt1, nbody_from_lpt2 , pdims):
-
+                              order, nbody_from_lpt1, nbody_from_lpt2, pdims):
     mesh_shape, box_shape = simulation_config
     # SINGLE DEVICE RUN
     cosmo._workspace = {}
-
     mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2
-
     initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                       sharding)
 
     print(f"sharded initial conditions {initial_conditions.sharding}")
-
     cosmo._workspace = {}
 
-
     @jax.jit
     def forward_model(initial_conditions, cosmo):
-
         dx, p, _ = lpt(cosmo,
-                      initial_conditions,
-                      a=0.1,
-                      order=order,
-                      halo_size=halo_size,
-                      sharding=sharding)
+                       initial_conditions,
+                       a=0.1,
+                       order=order,
+                       halo_size=halo_size,
+                       sharding=sharding)
 
         ode_fn = ODETerm(
             make_diffrax_ode(mesh_shape,
-                            paint_absolute_pos=False,
-                            halo_size=halo_size,
-                            sharding=sharding))
+                             paint_absolute_pos=False,
+                             halo_size=halo_size,
+                             sharding=sharding))
 
         y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
 
         solver = Dopri5()
@@ -219,7 +216,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
                                    icoeff=1,
                                    dcoeff=0)
 
-
         saveat = SaveAt(t1=True)
 
         solutions = diffeqsolve(ode_fn,
                                 solver,
@@ -231,15 +227,12 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
                                 stepsize_controller=controller,
                                 saveat=saveat)
 
-
-
         multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
                                                 halo_size=halo_size,
                                                 sharding=sharding)
 
         return multi_device_final_field
 
-
     @jax.jit
     def model(initial_conditions, cosmo):
         final_field = forward_model(initial_conditions, cosmo)
@@ -251,38 +244,33 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
     shifted_initial_conditions = initial_conditions + jax.random.normal(
         jax.random.key(42), initial_conditions.shape) * 5
 
-
    good_grads = jax.grad(model)(initial_conditions, cosmo)
    off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
 
-
-    assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-
+    assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
+                                                ndim=3)
+    assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
+                                               ndim=3)
 
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("pdims", pdims)
-def test_fwd_rev_gradients(cosmo,pdims):
-
+def test_fwd_rev_gradients(cosmo, pdims):
     mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
     # SINGLE DEVICE RUN
     cosmo._workspace = {}
-
     mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2
-
     initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
     initial_conditions = lax.with_sharding_constraint(initial_conditions,
                                                       sharding)
 
     print(f"sharded initial conditions {initial_conditions.sharding}")
-
     cosmo._workspace = {}
 
-
     @partial(jax.jit, static_argnums=(2, 3, 4))
     def compute_forces(initial_conditions,
                        cosmo,
@@ -290,16 +278,15 @@ def test_fwd_rev_gradients(cosmo,pdims):
                        halo_size=0,
                        sharding=None):
 
-
         paint_absolute_pos = False
 
         particles = jnp.zeros_like(initial_conditions,
                                    shape=(*initial_conditions.shape, 3))
-
         a = jnp.atleast_1d(a)
         E = jnp.sqrt(jc.background.Esqr(cosmo, a))
-
-        initial_conditions = jax.lax.with_sharding_constraint(initial_conditions,sharding)
+
+        initial_conditions = jax.lax.with_sharding_constraint(
+            initial_conditions, sharding)
         delta_k = fft3d(initial_conditions)
         out_sharding = get_fft_output_sharding(sharding)
         delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
@@ -310,10 +297,8 @@ def test_fwd_rev_gradients(cosmo,pdims):
                                   halo_size=halo_size,
                                   sharding=sharding)
 
-
         return initial_force[..., 0]
 
-
     forces = compute_forces(initial_conditions,
                             cosmo,
                             halo_size=halo_size,
@@ -327,41 +312,44 @@ def test_fwd_rev_gradients(cosmo,pdims):
                             halo_size=halo_size,
                             sharding=sharding)
 
-
     print(f"Forces sharding is {forces.sharding}")
     print(f"Backward gradient sharding is {back_gradient.sharding}")
     print(f"Forward gradient sharding is {fwd_gradient.sharding}")
-    assert forces.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-
+    assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
+                                            ndim=3)
+    assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
+        initial_conditions.sharding, ndim=3)
+    assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
+                                                  ndim=3)
 
 
 @pytest.mark.distributed
 @pytest.mark.parametrize("pdims", pdims)
-def test_vmap(cosmo,pdims):
+def test_vmap(cosmo, pdims):
 
     mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
     # SINGLE DEVICE RUN
     cosmo._workspace = {}
-
     mesh = jax.make_mesh(pdims, ('x', 'y'))
     sharding = NamedSharding(mesh, P('x', 'y'))
     halo_size = mesh_shape[0] // 2
+    single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
+                                                      mesh_shape)
+    initial_conditions = lax.with_sharding_constraint(
+        single_dev_initial_conditions, sharding)
 
-    single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
-    initial_conditions = lax.with_sharding_constraint(single_dev_initial_conditions,
-                                                      sharding)
-
-    single_ics = jnp.stack([single_dev_initial_conditions, single_dev_initial_conditions , single_dev_initial_conditions])
-    sharded_ics = jnp.stack([initial_conditions, initial_conditions , initial_conditions])
+    single_ics = jnp.stack([
+        single_dev_initial_conditions, single_dev_initial_conditions,
+        single_dev_initial_conditions
+    ])
+    sharded_ics = jnp.stack(
+        [initial_conditions, initial_conditions, initial_conditions])
 
     print(f"unsharded initial conditions batch {single_ics.sharding}")
     print(f"sharded initial conditions batch {sharded_ics.sharding}")
 
     cosmo._workspace = {}
 
-
     @partial(jax.jit, static_argnums=(2, 3, 4))
     def compute_forces(initial_conditions,
                        cosmo,
@@ -369,16 +357,15 @@ def test_vmap(cosmo,pdims):
                        halo_size=0,
                        sharding=None):
 
-
         paint_absolute_pos = False
 
         particles = jnp.zeros_like(initial_conditions,
                                    shape=(*initial_conditions.shape, 3))
-
        a = jnp.atleast_1d(a)
        E = jnp.sqrt(jc.background.Esqr(cosmo, a))
-
-        initial_conditions = jax.lax.with_sharding_constraint(initial_conditions,sharding)
+
+        initial_conditions = jax.lax.with_sharding_constraint(
+            initial_conditions, sharding)
         delta_k = fft3d(initial_conditions)
         out_sharding = get_fft_output_sharding(sharding)
         delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
@@ -389,14 +376,13 @@ def test_vmap(cosmo,pdims):
                                   halo_size=halo_size,
                                   sharding=sharding)
 
-
         return initial_force[..., 0]
 
     def fn(ic):
         return compute_forces(ic,
-                             cosmo,
-                             halo_size=halo_size,
-                             sharding=sharding)
+                              cosmo,
+                              halo_size=halo_size,
+                              sharding=sharding)
 
     v_compute_forces = jax.vmap(fn)
 
@@ -409,8 +395,8 @@ def test_vmap(cosmo,pdims):
 
     assert single_dev_forces.ndim == 4
     assert sharded_forces.ndim == 4
 
-
     print(f"Sharded forces {sharded_forces.sharding}")
-    assert sharded_forces[0].sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
-    assert sharded_forces.sharding.spec[0] == None
\ No newline at end of file
+    assert sharded_forces[0].sharding.is_equivalent_to(
+        initial_conditions.sharding, ndim=3)
+    assert sharded_forces.sharding.spec[0] == None
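
Note, not part of the patch: the assertions this commit reflows all go through
jax.sharding's `Sharding.is_equivalent_to(other, ndim=...)`, which compares how
only the first `ndim` axes are laid out, so a batched result (or a gradient)
can be checked against the unbatched input sharding. Below is a minimal,
self-contained sketch of that pattern. It assumes a recent JAX that provides
`jax.make_mesh`, and it uses a trivial 1x1 device mesh so it runs on a single
CPU; the array sizes and the doubling function are illustrative only, not
taken from the tests above (which use 8-device meshes built from `pdims`).

    import jax
    import jax.numpy as jnp
    from jax.sharding import NamedSharding
    from jax.sharding import PartitionSpec as P

    # A 1x1 mesh works on any machine; swap in e.g. (4, 2) on 8 devices.
    mesh = jax.make_mesh((1, 1), ('x', 'y'))
    sharding = NamedSharding(mesh, P('x', 'y'))

    # A 3D field placed with the reference sharding.
    field = jax.device_put(jnp.zeros((8, 8, 8)), sharding)

    # vmap over a stacked batch: the mapped axis is prepended, and the
    # per-element layout is what `is_equivalent_to(..., ndim=3)` inspects.
    batched = jnp.stack([field, field, field])
    forces = jax.vmap(lambda x: x * 2.0)(batched)

    # Restricting the comparison to the 3 field axes lets each batch
    # element be checked against the unbatched reference sharding.
    assert forces[0].sharding.is_equivalent_to(sharding, ndim=3)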