diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 78d63ef..3083f08 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -12,7 +12,7 @@ from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter -def _cic_paint_impl(grid_mesh, positions, weight=1.): +def _cic_paint_impl(grid_mesh, positions, weight=None): """ Paints positions onto mesh mesh: [nx, ny, nz] displacement field: [nx, ny, nz, 3] @@ -27,10 +27,12 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.): neighboor_coords = floor + connection kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] - if jnp.isscalar(weight): - kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) - else: - kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel) + if weight is not None: + if jnp.isscalar(weight): + kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel) + else: + kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), + kernel) neighboor_coords = jnp.mod( neighboor_coords.reshape([-1, 8, 3]).astype('int32'), @@ -46,13 +48,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.): @partial(jax.jit, static_argnums=(3, 4)) -def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None): - - if sharding is not None: - print(""" - WARNING : absolute painting is not recommended in multi-device mode. - Please use relative painting instead. - """) +def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): positions = positions.reshape((*grid_mesh.shape, 3)) @@ -61,11 +57,9 @@ def cic_paint(grid_mesh, positions, weight=1., halo_size=0, sharding=None): gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None spec = sharding.spec if isinstance(sharding, NamedSharding) else P() - weight_spec = P() if jnp.isscalar(weight) else spec - grid_mesh = autoshmap(_cic_paint_impl, gpu_mesh=gpu_mesh, - in_specs=(spec, spec, weight_spec), + in_specs=(spec, spec, P()), out_specs=spec)(grid_mesh, positions, weight) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, @@ -157,10 +151,7 @@ def cic_paint_2d(mesh, positions, weight): return mesh -def _cic_paint_dx_impl(displacements, - weight=1., - halo_size=0, - chunk_size=2**24): +def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] @@ -199,13 +190,13 @@ def cic_paint_dx(displacements, gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None spec = sharding.spec if isinstance(sharding, NamedSharding) else P() - weight_spec = P() if jnp.isscalar(weight) else spec grid_mesh = autoshmap(partial(_cic_paint_dx_impl, halo_size=halo_size, + weight=weight, chunk_size=chunk_size), gpu_mesh=gpu_mesh, - in_specs=(spec, weight_spec), - out_specs=spec)(displacements, weight) + in_specs=spec, + out_specs=spec)(displacements) grid_mesh = halo_exchange(grid_mesh, halo_extents=halo_extents, diff --git a/notebooks/README.md b/notebooks/README.md index 872fdd4..43d9d0b 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -37,50 +37,3 @@ Each notebook includes installation instructions and guidelines for configuring - **SLURM** for job scheduling on clusters (if running multi-host setups) > **Note**: These notebooks are tested on the **Jean Zay** supercomputer and may require configuration changes for different HPC clusters. 
-
-## Caveats
-
-### Cloud-in-Cell (CIC) Painting (Single Device)
-
-There are two ways to perform CIC painting in JAXPM. The first is `cic_paint`, which paints absolute particle positions onto the mesh. The second is `cic_paint_dx`, which paints relative particle positions onto the mesh (using uniform particles). The absolute version is faster, at the cost of higher memory usage.
-
-In order to use relative painting you need to:
-
- - Set the `particles` argument of the `lpt` function from `jaxpm.pm` to `None`
- - Set `paint_absolute_pos` to `False` in the `make_ode_fn` or `make_diffrax_ode` function from `jaxpm.pm` (it is `True` by default)
-
-Otherwise, set `particles` to the starting particles of your choice and leave `paint_absolute_pos` as `True` (the default).
-
-### Cloud-in-Cell (CIC) Painting (Multi Device)
-
-Both the `cic_paint` and `cic_paint_dx` functions are available in multi-device mode.
-
-You need to set the `sharding` and `halo_size` arguments, which are explained in the notebook [03-MultiGPU_PM_Halo.ipynb](03-MultiGPU_PM_Halo.ipynb).
-
-Note that `cic_paint` is not as accurate as `cic_paint_dx` in multi-device mode and is therefore not recommended.
-
-Using relative painting in multi-device mode works just like in single-device mode:\
-set the `particles` argument of the `lpt` function from `jaxpm.pm` to `None` and set `paint_absolute_pos` to `False`.
-
-### Distributed PM
-
-To run a distributed PM simulation, follow the examples in notebooks [03](03-MultiGPU_PM_Halo.ipynb) and [05](05-MultiHost_PM.ipynb) (the latter for multi-host setups).
-
-In short, you need to set the `sharding` and `halo_size` arguments in `lpt`, `linear_field`, the `make_ode` functions, and `pm_forces` if you use it.
-
-Mismatched shardings will give you errors and unexpected results.
-
-You can also use `normal_field` and `uniform_particles` from `jaxpm.distributed` to create the fields and particles with a sharding.
-
-### Choosing the right pdims
-
-pdims are processor dimensions,\
-explained in more detail in the jaxDecomp paper [here](https://github.com/DifferentiableUniverseInitiative/jaxDecomp).
-
-For 8 devices, three decompositions are possible:
-- (1, 8)
-- (2, 4), (4, 2)
-- (8, 1)
-
-(1, X) should be the fastest; (2, X) or (X, 2) is more accurate but slightly slower,\
-and (X, 1) gives the least accurate results for some reason, so it is not recommended.
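To make the painting caveats concrete, here is a minimal single-device sketch of both painting modes, reflecting the `jaxpm/painting.py` change above where `weight` now defaults to `None`. The mesh size, random positions/displacements, and the `weights` array are illustrative assumptions, and calling `cic_paint_dx` with only the displacements assumes its remaining arguments keep their defaults.

```python
import jax
import jax.numpy as jnp

from jaxpm.painting import cic_paint, cic_paint_dx

mesh_shape = (64, 64, 64)  # illustrative grid size

# Absolute painting: one particle per cell, positions in grid units,
# shaped (or reshapeable) as [nx, ny, nz, 3].
positions = jax.random.uniform(jax.random.PRNGKey(0),
                               (*mesh_shape, 3),
                               minval=0.0,
                               maxval=mesh_shape[0])

# Unweighted painting: `weight=None` is now the default.
field = cic_paint(jnp.zeros(mesh_shape), positions)

# Per-particle weights (e.g. masses); reshaped internally to the particle grid.
weights = jax.random.uniform(jax.random.PRNGKey(1), (mesh_shape[0]**3,))
weighted_field = cic_paint(jnp.zeros(mesh_shape), positions, weight=weights)

# Relative painting: displacements of uniformly placed particles,
# shape [nx, ny, nz, 3]; the target mesh is allocated internally.
displacements = 0.1 * jax.random.normal(jax.random.PRNGKey(2), (*mesh_shape, 3))
relative_field = cic_paint_dx(displacements)
```

The absolute path keeps an explicit positions array in memory, which is why the caveats above recommend the relative path for large or multi-device runs.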
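And a sketch of the distributed, relative-painting path described under "Distributed PM" above, assuming 8 visible devices and the (1, 8) decomposition. The cosmology, mesh size, scale factor, and random initial conditions are placeholders; omitting `particles` in `lpt` (its default, `None`) selects the relative/displacement output.

```python
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax import lax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.painting import cic_paint_dx
from jaxpm.pm import lpt

mesh_shape = (64, 64, 64)        # illustrative
halo_size = mesh_shape[0] // 2   # same heuristic as in the tests below
cosmo = jc.Planck15()            # illustrative cosmology

# (1, 8) is the recommended decomposition for 8 devices (see the pdims note above).
mesh = jax.make_mesh((1, 8), ('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

# Stand-in for a properly generated linear field, constrained to the sharding.
initial_conditions = jax.random.normal(jax.random.PRNGKey(0), mesh_shape)
initial_conditions = lax.with_sharding_constraint(initial_conditions, sharding)

# Relative painting: no `particles` argument, so lpt returns displacements
# (and momenta) on the uniform particle grid.
dx, p, _ = lpt(cosmo,
               initial_conditions,
               a=0.1,
               halo_size=halo_size,
               sharding=sharding)

field = cic_paint_dx(dx, halo_size=halo_size, sharding=sharding)
```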
diff --git a/tests/test_against_fpm.py b/tests/test_against_fpm.py index 5ebcbc2..6d17939 100644 --- a/tests/test_against_fpm.py +++ b/tests/test_against_fpm.py @@ -76,7 +76,7 @@ def test_nbody_absolute(simulation_config, initial_conditions, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) solver = Dopri5() controller = PIDController(rtol=1e-8, @@ -95,7 +95,6 @@ def test_nbody_absolute(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, - args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -122,7 +121,8 @@ def test_nbody_relative(simulation_config, initial_conditions, # Initial displacement dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) + ode_fn = ODETerm( + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) solver = Dopri5() controller = PIDController(rtol=1e-9, @@ -141,7 +141,6 @@ def test_nbody_relative(simulation_config, initial_conditions, t1=1.0, dt0=None, y0=y0, - args=cosmo, stepsize_controller=controller, saveat=saveat) diff --git a/tests/test_distributed_pm.py b/tests/test_distributed_pm.py index 69c37ed..eb44456 100644 --- a/tests/test_distributed_pm.py +++ b/tests/test_distributed_pm.py @@ -2,11 +2,8 @@ from conftest import initialize_distributed initialize_distributed() # ignore : E402 -from functools import partial # noqa : E402 - import jax # noqa : E402 import jax.numpy as jnp # noqa : E402 -import jax_cosmo as jc # noqa : E402 import pytest # noqa : E402 from diffrax import SaveAt # noqa : E402 from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve @@ -15,31 +12,19 @@ from jax import lax # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P # noqa : E402 -from jaxdecomp import get_fft_output_sharding from jaxpm.distributed import uniform_particles # noqa : E402 -from jaxpm.distributed import fft3d, ifft3d from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402 -from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402 +from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402 -_TOLERANCE = 1e-1 # 🙃🙃 - -pdims = [(1, 8), (8, 1), (4, 2), (2, 4)] +_TOLERANCE = 3.0 # 🙃🙃 @pytest.mark.distributed @pytest.mark.parametrize("order", [1, 2]) -@pytest.mark.parametrize("pdims", pdims) @pytest.mark.parametrize("absolute_painting", [True, False]) def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, - pdims, absolute_painting): - - if absolute_painting: - pytest.skip("Absolute painting is not recommended in distributed mode") - - painting_str = "absolute" if absolute_painting else "relative" - print("=" * 50) - print(f"Running with {painting_str} painting and pdims {pdims} ...") + absolute_painting): mesh_shape, box_shape = simulation_config # SINGLE DEVICE RUN @@ -75,7 +60,6 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, t1=1.0, dt0=None, y0=y0, - args=cosmo, stepsize_controller=controller, saveat=saveat) @@ -88,7 +72,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, print("Done with single device run") # MULTI DEVICE RUN - mesh = jax.make_mesh(pdims, ('x', 'y')) + mesh = jax.make_mesh((1, 8), ('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y')) halo_size = mesh_shape[0] // 2 @@ -144,23 +128,16 @@ def 
test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, t1=1.0, dt0=None, y0=y0, - args=cosmo, stepsize_controller=controller, saveat=saveat) - final_field = solutions.ys[-1, 0] - print(f"Final field sharding is {final_field.sharding}") - - assert final_field.sharding.is_equivalent_to(sharding , ndim=3) \ - , f"Final field sharding is not correct .. should be {sharding} it is instead {final_field.sharding}" - if absolute_painting: multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape), - final_field, + solutions.ys[-1, 0], halo_size=halo_size, sharding=sharding) else: - multi_device_final_field = cic_paint_dx(final_field, + multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], halo_size=halo_size, sharding=sharding) @@ -171,230 +148,3 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, print(f"MSE is {mse}") assert mse < _TOLERANCE - - -@pytest.mark.distributed -@pytest.mark.parametrize("order", [1, 2]) -@pytest.mark.parametrize("pdims", pdims) -def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, - order, nbody_from_lpt1, nbody_from_lpt2, pdims): - - mesh_shape, box_shape = simulation_config - # SINGLE DEVICE RUN - cosmo._workspace = {} - - mesh = jax.make_mesh(pdims, ('x', 'y')) - sharding = NamedSharding(mesh, P('x', 'y')) - halo_size = mesh_shape[0] // 2 - - initial_conditions = lax.with_sharding_constraint(initial_conditions, - sharding) - - print(f"sharded initial conditions {initial_conditions.sharding}") - cosmo._workspace = {} - - @jax.jit - def forward_model(initial_conditions, cosmo): - - dx, p, _ = lpt(cosmo, - initial_conditions, - a=0.1, - order=order, - halo_size=halo_size, - sharding=sharding) - ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, - paint_absolute_pos=False, - halo_size=halo_size, - sharding=sharding)) - y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) - - solver = Dopri5() - controller = PIDController(rtol=1e-8, - atol=1e-8, - pcoeff=0.4, - icoeff=1, - dcoeff=0) - - saveat = SaveAt(t1=True) - solutions = diffeqsolve(ode_fn, - solver, - t0=0.1, - t1=1.0, - dt0=None, - y0=y0, - args=cosmo, - stepsize_controller=controller, - saveat=saveat) - - multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0], - halo_size=halo_size, - sharding=sharding) - - return multi_device_final_field - - @jax.jit - def model(initial_conditions, cosmo): - final_field = forward_model(initial_conditions, cosmo) - return MSE(final_field, - nbody_from_lpt1 if order == 1 else nbody_from_lpt2) - - obs_val = model(initial_conditions, cosmo) - - shifted_initial_conditions = initial_conditions + jax.random.normal( - jax.random.key(42), initial_conditions.shape) * 5 - - good_grads = jax.grad(model)(initial_conditions, cosmo) - off_grads = jax.grad(model)(shifted_initial_conditions, cosmo) - - assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding, - ndim=3) - assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding, - ndim=3) - - -@pytest.mark.distributed -@pytest.mark.parametrize("pdims", pdims) -def test_fwd_rev_gradients(cosmo, pdims): - - mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) - cosmo._workspace = {} - - mesh = jax.make_mesh(pdims, ('x', 'y')) - sharding = NamedSharding(mesh, P('x', 'y')) - halo_size = mesh_shape[0] // 2 - - initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape) - initial_conditions = lax.with_sharding_constraint(initial_conditions, - sharding) - print(f"sharded initial conditions {initial_conditions.sharding}") - 
cosmo._workspace = {} - - @partial(jax.jit, static_argnums=(2, 3, 4)) - def compute_forces(initial_conditions, - cosmo, - a=0.5, - halo_size=0, - sharding=None): - - paint_absolute_pos = False - particles = jnp.zeros_like(initial_conditions, - shape=(*initial_conditions.shape, 3)) - - a = jnp.atleast_1d(a) - E = jnp.sqrt(jc.background.Esqr(cosmo, a)) - - initial_conditions = jax.lax.with_sharding_constraint( - initial_conditions, sharding) - delta_k = fft3d(initial_conditions) - out_sharding = get_fft_output_sharding(sharding) - delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding) - - initial_force = pm_forces(particles, - delta=delta_k, - paint_absolute_pos=paint_absolute_pos, - halo_size=halo_size, - sharding=sharding) - - return initial_force[..., 0] - - forces = compute_forces(initial_conditions, - cosmo, - halo_size=halo_size, - sharding=sharding) - back_gradient = jax.jacrev(compute_forces)(initial_conditions, - cosmo, - halo_size=halo_size, - sharding=sharding) - fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions, - cosmo, - halo_size=halo_size, - sharding=sharding) - - print(f"Forces sharding is {forces.sharding}") - print(f"Backward gradient sharding is {back_gradient.sharding}") - print(f"Forward gradient sharding is {fwd_gradient.sharding}") - assert forces.sharding.is_equivalent_to(initial_conditions.sharding, - ndim=3) - assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to( - initial_conditions.sharding, ndim=3) - assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding, - ndim=3) - - -@pytest.mark.distributed -@pytest.mark.parametrize("pdims", pdims) -def test_vmap(cosmo, pdims): - - mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0) - cosmo._workspace = {} - - mesh = jax.make_mesh(pdims, ('x', 'y')) - sharding = NamedSharding(mesh, P('x', 'y')) - halo_size = mesh_shape[0] // 2 - - single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42), - mesh_shape) - initial_conditions = lax.with_sharding_constraint( - single_dev_initial_conditions, sharding) - - single_ics = jnp.stack([ - single_dev_initial_conditions, single_dev_initial_conditions, - single_dev_initial_conditions - ]) - sharded_ics = jnp.stack( - [initial_conditions, initial_conditions, initial_conditions]) - print(f"unsharded initial conditions batch {single_ics.sharding}") - print(f"sharded initial conditions batch {sharded_ics.sharding}") - cosmo._workspace = {} - - @partial(jax.jit, static_argnums=(2, 3, 4)) - def compute_forces(initial_conditions, - cosmo, - a=0.5, - halo_size=0, - sharding=None): - - paint_absolute_pos = False - particles = jnp.zeros_like(initial_conditions, - shape=(*initial_conditions.shape, 3)) - - a = jnp.atleast_1d(a) - E = jnp.sqrt(jc.background.Esqr(cosmo, a)) - - initial_conditions = jax.lax.with_sharding_constraint( - initial_conditions, sharding) - delta_k = fft3d(initial_conditions) - out_sharding = get_fft_output_sharding(sharding) - delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding) - - initial_force = pm_forces(particles, - delta=delta_k, - paint_absolute_pos=paint_absolute_pos, - halo_size=halo_size, - sharding=sharding) - - return initial_force[..., 0] - - def fn(ic): - return compute_forces(ic, - cosmo, - halo_size=halo_size, - sharding=sharding) - - v_compute_forces = jax.vmap(fn) - - print(f"single_ics shape {single_ics.shape}") - print(f"sharded_ics shape {sharded_ics.shape}") - - single_dev_forces = v_compute_forces(single_ics) - sharded_forces = v_compute_forces(sharded_ics) - - assert 
single_dev_forces.ndim == 4 - assert sharded_forces.ndim == 4 - - print(f"Sharded forces {sharded_forces.sharding}") - - assert sharded_forces[0].sharding.is_equivalent_to( - initial_conditions.sharding, ndim=3) - assert sharded_forces.sharding.spec[0] == None diff --git a/tests/test_gradients.py b/tests/test_gradients.py index 1f611aa..bb48920 100644 --- a/tests/test_gradients.py +++ b/tests/test_gradients.py @@ -39,7 +39,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, particles, a=lpt_scale_factor, order=order) - ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) + ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape)) y0 = jnp.stack([particles + dx, p]) else: @@ -48,7 +48,7 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, a=lpt_scale_factor, order=order) ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) + make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False)) y0 = jnp.stack([dx, p]) solver = Dopri5() @@ -66,7 +66,6 @@ def test_nbody_grad(simulation_config, initial_conditions, lpt_scale_factor, t1=1.0, dt0=None, y0=y0, - args=cosmo, adjoint=adjoint, stepsize_controller=controller, saveat=saveat)
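The test changes above all follow the same API migration: `make_diffrax_ode` now takes the cosmology as its first argument, and `args=cosmo` is no longer passed to `diffeqsolve`. A minimal single-device sketch of the updated call pattern, with the cosmology, grid size, tolerances, and random initial conditions as illustrative placeholders:

```python
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve

from jaxpm.pm import lpt, make_diffrax_ode

cosmo = jc.Planck15()        # illustrative cosmology
mesh_shape = (64, 64, 64)    # illustrative grid

# Stand-in for a properly generated linear field.
initial_conditions = jax.random.normal(jax.random.PRNGKey(0), mesh_shape)

# First-order LPT displacements and momenta (relative variables).
dx, p, _ = lpt(cosmo, initial_conditions, a=0.1, order=1)

# New signature: the cosmology is bound when the ODE term is built ...
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))

solver = Dopri5()
controller = PIDController(rtol=1e-8, atol=1e-8, pcoeff=0.4, icoeff=1, dcoeff=0)
saveat = SaveAt(t1=True)

solutions = diffeqsolve(ode_fn,
                        solver,
                        t0=0.1,
                        t1=1.0,
                        dt0=None,
                        y0=jnp.stack([dx, p]),
                        # ... so no `args=cosmo` here any more.
                        stepsize_controller=controller,
                        saveat=saveat)

final_displacements = solutions.ys[-1, 0]
```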