mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-04 11:10:53 +00:00
format
This commit is contained in:
parent
580387ce1c
commit
9f494da317
3 changed files with 62 additions and 75 deletions
|
@ -30,8 +30,7 @@ def _cic_paint_impl(grid_mesh, positions, weight=1.):
|
|||
if jnp.isscalar(weight):
|
||||
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
|
||||
else:
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
|
||||
kernel)
|
||||
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]), kernel)
|
||||
|
||||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
|
@ -158,7 +157,10 @@ def cic_paint_2d(mesh, positions, weight):
|
|||
return mesh
|
||||
|
||||
|
||||
def _cic_paint_dx_impl(displacements, weight=1. , halo_size=0 , chunk_size=2**24):
|
||||
def _cic_paint_dx_impl(displacements,
|
||||
weight=1.,
|
||||
halo_size=0,
|
||||
chunk_size=2**24):
|
||||
|
||||
halo_x, _ = halo_size[0]
|
||||
halo_y, _ = halo_size[1]
|
||||
|
@ -203,7 +205,7 @@ def cic_paint_dx(displacements,
|
|||
chunk_size=chunk_size),
|
||||
gpu_mesh=gpu_mesh,
|
||||
in_specs=(spec, weight_spec),
|
||||
out_specs=spec)(displacements , weight)
|
||||
out_specs=spec)(displacements, weight)
|
||||
|
||||
grid_mesh = halo_exchange(grid_mesh,
|
||||
halo_extents=halo_extents,
|
||||
|
|
|
@ -121,8 +121,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
|
|
|
@ -2,8 +2,11 @@ from conftest import initialize_distributed
|
|||
|
||||
initialize_distributed() # ignore : E402
|
||||
|
||||
from functools import partial # noqa : E402
|
||||
|
||||
import jax # noqa : E402
|
||||
import jax.numpy as jnp # noqa : E402
|
||||
import jax_cosmo as jc # noqa : E402
|
||||
import pytest # noqa : E402
|
||||
from diffrax import SaveAt # noqa : E402
|
||||
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
||||
|
@ -12,30 +15,30 @@ from jax import lax # noqa : E402
|
|||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P # noqa : E402
|
||||
from jaxdecomp import get_fft_output_sharding
|
||||
|
||||
from jaxpm.distributed import uniform_particles # noqa : E402
|
||||
from jaxpm.distributed import fft3d, ifft3d
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode , pm_forces # noqa : E402
|
||||
from functools import partial # noqa : E402
|
||||
import jax_cosmo as jc # noqa : E402
|
||||
from jaxpm.distributed import fft3d , ifft3d
|
||||
from jaxdecomp import get_fft_output_sharding
|
||||
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
|
||||
|
||||
_TOLERANCE = 1e-1 # 🙃🙃
|
||||
|
||||
pdims = [(1, 8) , (8 , 1) , (4 , 2), (2 , 4)]
|
||||
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,pdims,
|
||||
absolute_painting):
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||
pdims, absolute_painting):
|
||||
|
||||
if absolute_painting:
|
||||
pytest.skip("Absolute painting is not recommended in distributed mode")
|
||||
|
||||
painting_str = "absolute" if absolute_painting else "relative"
|
||||
print("="*50)
|
||||
print("=" * 50)
|
||||
print(f"Running with {painting_str} painting and pdims {pdims} ...")
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
|
@ -170,46 +173,40 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,pdims
|
|||
assert mse < _TOLERANCE
|
||||
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
||||
order, nbody_from_lpt1, nbody_from_lpt2 , pdims):
|
||||
|
||||
order, nbody_from_lpt1, nbody_from_lpt2, pdims):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
||||
@jax.jit
|
||||
def forward_model(initial_conditions, cosmo):
|
||||
|
||||
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||
|
||||
solver = Dopri5()
|
||||
|
@ -219,7 +216,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
|||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
|
@ -231,15 +227,12 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
|||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
|
||||
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
return multi_device_final_field
|
||||
|
||||
|
||||
@jax.jit
|
||||
def model(initial_conditions, cosmo):
|
||||
final_field = forward_model(initial_conditions, cosmo)
|
||||
|
@ -251,38 +244,33 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
|||
shifted_initial_conditions = initial_conditions + jax.random.normal(
|
||||
jax.random.key(42), initial_conditions.shape) * 5
|
||||
|
||||
|
||||
good_grads = jax.grad(model)(initial_conditions, cosmo)
|
||||
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
|
||||
|
||||
|
||||
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
|
||||
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
|
||||
|
||||
assert good_grads.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
assert off_grads.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
def test_fwd_rev_gradients(cosmo,pdims):
|
||||
|
||||
def test_fwd_rev_gradients(cosmo, pdims):
|
||||
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
|
||||
initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
|
||||
initial_conditions = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
|
@ -290,16 +278,15 @@ def test_fwd_rev_gradients(cosmo,pdims):
|
|||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
|
||||
paint_absolute_pos = False
|
||||
particles = jnp.zeros_like(initial_conditions,
|
||||
shape=(*initial_conditions.shape, 3))
|
||||
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
||||
initial_conditions = jax.lax.with_sharding_constraint(initial_conditions,sharding)
|
||||
|
||||
initial_conditions = jax.lax.with_sharding_constraint(
|
||||
initial_conditions, sharding)
|
||||
delta_k = fft3d(initial_conditions)
|
||||
out_sharding = get_fft_output_sharding(sharding)
|
||||
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
|
||||
|
@ -310,10 +297,8 @@ def test_fwd_rev_gradients(cosmo,pdims):
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
|
||||
return initial_force[..., 0]
|
||||
|
||||
|
||||
forces = compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
|
@ -327,41 +312,44 @@ def test_fwd_rev_gradients(cosmo,pdims):
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
|
||||
print(f"Forces sharding is {forces.sharding}")
|
||||
print(f"Backward gradient sharding is {back_gradient.sharding}")
|
||||
print(f"Forward gradient sharding is {fwd_gradient.sharding}")
|
||||
assert forces.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
|
||||
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
|
||||
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
|
||||
|
||||
assert forces.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
assert back_gradient[0, 0, 0, ...].sharding.is_equivalent_to(
|
||||
initial_conditions.sharding, ndim=3)
|
||||
assert fwd_gradient.sharding.is_equivalent_to(initial_conditions.sharding,
|
||||
ndim=3)
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("pdims", pdims)
|
||||
def test_vmap(cosmo,pdims):
|
||||
def test_vmap(cosmo, pdims):
|
||||
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
||||
mesh = jax.make_mesh(pdims, ('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42),
|
||||
mesh_shape)
|
||||
initial_conditions = lax.with_sharding_constraint(
|
||||
single_dev_initial_conditions, sharding)
|
||||
|
||||
single_dev_initial_conditions = jax.random.normal(jax.random.PRNGKey(42), mesh_shape)
|
||||
initial_conditions = lax.with_sharding_constraint(single_dev_initial_conditions,
|
||||
sharding)
|
||||
|
||||
single_ics = jnp.stack([single_dev_initial_conditions, single_dev_initial_conditions , single_dev_initial_conditions])
|
||||
sharded_ics = jnp.stack([initial_conditions, initial_conditions , initial_conditions])
|
||||
single_ics = jnp.stack([
|
||||
single_dev_initial_conditions, single_dev_initial_conditions,
|
||||
single_dev_initial_conditions
|
||||
])
|
||||
sharded_ics = jnp.stack(
|
||||
[initial_conditions, initial_conditions, initial_conditions])
|
||||
print(f"unsharded initial conditions batch {single_ics.sharding}")
|
||||
print(f"sharded initial conditions batch {sharded_ics.sharding}")
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2, 3, 4))
|
||||
def compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
|
@ -369,16 +357,15 @@ def test_vmap(cosmo,pdims):
|
|||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
|
||||
paint_absolute_pos = False
|
||||
particles = jnp.zeros_like(initial_conditions,
|
||||
shape=(*initial_conditions.shape, 3))
|
||||
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
||||
initial_conditions = jax.lax.with_sharding_constraint(initial_conditions,sharding)
|
||||
|
||||
initial_conditions = jax.lax.with_sharding_constraint(
|
||||
initial_conditions, sharding)
|
||||
delta_k = fft3d(initial_conditions)
|
||||
out_sharding = get_fft_output_sharding(sharding)
|
||||
delta_k = jax.lax.with_sharding_constraint(delta_k, out_sharding)
|
||||
|
@ -389,14 +376,13 @@ def test_vmap(cosmo,pdims):
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
|
||||
return initial_force[..., 0]
|
||||
|
||||
def fn(ic):
|
||||
return compute_forces(ic,
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
cosmo,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
v_compute_forces = jax.vmap(fn)
|
||||
|
||||
|
@ -409,8 +395,8 @@ def test_vmap(cosmo,pdims):
|
|||
assert single_dev_forces.ndim == 4
|
||||
assert sharded_forces.ndim == 4
|
||||
|
||||
|
||||
print(f"Sharded forces {sharded_forces.sharding}")
|
||||
|
||||
assert sharded_forces[0].sharding.is_equivalent_to(initial_conditions.sharding , ndim=3)
|
||||
assert sharded_forces.sharding.spec[0] == None
|
||||
assert sharded_forces[0].sharding.is_equivalent_to(
|
||||
initial_conditions.sharding, ndim=3)
|
||||
assert sharded_forces.sharding.spec[0] == None
|
||||
|
|
Loading…
Add table
Reference in a new issue