mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 12:01:12 +00:00
format
This commit is contained in:
parent
20fe25c324
commit
1f5c619531
10 changed files with 290 additions and 210 deletions
|
@ -174,18 +174,22 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
|||
|
||||
return fpm_mesh
|
||||
|
||||
|
||||
def compare_sharding(sharding1, sharding2):
|
||||
|
||||
def get_axis_size(sharding, idx):
|
||||
axis_name = sharding.spec[idx]
|
||||
if axis_name is None:
|
||||
return 1
|
||||
else:
|
||||
return sharding.mesh.shape[sharding.spec[idx]]
|
||||
|
||||
def get_pdims_from_sharding(sharding):
|
||||
return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))])
|
||||
return tuple(
|
||||
[get_axis_size(sharding, i) for i in range(len(sharding.spec))])
|
||||
|
||||
pdims1 = get_pdims_from_sharding(sharding1)
|
||||
pdims2 = get_pdims_from_sharding(sharding2)
|
||||
pdims1 = pdims1 + (1,) * (3 - len(pdims1))
|
||||
pdims2 = pdims2 + (1,) * (3 - len(pdims2))
|
||||
return pdims1 == pdims2
|
||||
pdims1 = pdims1 + (1, ) * (3 - len(pdims1))
|
||||
pdims2 = pdims2 + (1, ) * (3 - len(pdims2))
|
||||
return pdims1 == pdims2
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
import jax
|
||||
import pytest
|
||||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||
from helpers import MSE, MSRE
|
||||
from jax import numpy as jnp
|
||||
|
||||
from jaxdecomp import ShardedArray
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
from jaxpm.pm import lpt, make_diffrax_ode
|
||||
from jaxpm.utils import power_spectrum
|
||||
import jax
|
||||
|
||||
_TOLERANCE = 1e-4
|
||||
_PM_TOLERANCE = 1e-3
|
||||
|
||||
|
@ -17,7 +18,8 @@ _PM_TOLERANCE = 1e-3
|
|||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
|
||||
shardedArrayAPI):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
@ -53,7 +55,8 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
|
||||
fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
|
||||
shardedArrayAPI):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
@ -77,12 +80,13 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
assert type(dx) == ShardedArray
|
||||
assert type(lpt_field) == ShardedArray
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||
def test_nbody_absolute(simulation_config, initial_conditions,
|
||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||
cosmo, order , shardedArrayAPI):
|
||||
cosmo, order, shardedArrayAPI):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
@ -110,7 +114,8 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p)
|
||||
y0 = jax.tree.map(lambda particles, dx, p: jnp.stack([particles + dx, p]),
|
||||
particles, dx, p)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
|
@ -135,7 +140,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
|
||||
if shardedArrayAPI:
|
||||
assert type(dx) == ShardedArray
|
||||
assert type( solutions.ys[-1, 0]) == ShardedArray
|
||||
assert type(solutions.ys[-1, 0]) == ShardedArray
|
||||
assert type(final_field) == ShardedArray
|
||||
|
||||
|
||||
|
@ -144,7 +149,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||
def test_nbody_relative(simulation_config, initial_conditions,
|
||||
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
|
||||
cosmo, order , shardedArrayAPI):
|
||||
cosmo, order, shardedArrayAPI):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
cosmo._workspace = {}
|
||||
|
@ -155,8 +160,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
|
@ -167,7 +171,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p)
|
||||
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||
|
||||
solutions = diffeqsolve(ode_fn,
|
||||
solver,
|
||||
|
@ -192,5 +196,5 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
|
||||
if shardedArrayAPI:
|
||||
assert type(dx) == ShardedArray
|
||||
assert type( solutions.ys[-1, 0]) == ShardedArray
|
||||
assert type(solutions.ys[-1, 0]) == ShardedArray
|
||||
assert type(final_field) == ShardedArray
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
from conftest import initialize_distributed , compare_sharding
|
||||
from conftest import compare_sharding, initialize_distributed
|
||||
|
||||
initialize_distributed() # ignore : E402
|
||||
|
||||
from functools import partial # noqa : E402
|
||||
|
||||
import jax # noqa : E402
|
||||
import jax.numpy as jnp # noqa : E402
|
||||
import jax_cosmo as jc # noqa : E402
|
||||
import pytest # noqa : E402
|
||||
from diffrax import SaveAt # noqa : E402
|
||||
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
|
||||
|
@ -12,13 +15,13 @@ from jax import lax # noqa : E402
|
|||
from jax.experimental.multihost_utils import process_allgather # noqa : E402
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.sharding import PartitionSpec as P # noqa : E402
|
||||
from jaxpm.pm import pm_forces # noqa : E402
|
||||
from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||
from jaxdecomp import ShardedArray # noqa : E402
|
||||
from functools import partial # noqa : E402
|
||||
import jax_cosmo as jc # noqa : E402
|
||||
|
||||
from jaxpm.distributed import fft3d, uniform_particles # noqa : E402
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
|
||||
from jaxpm.pm import pm_forces # noqa : E402
|
||||
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
|
||||
|
||||
_TOLERANCE = 3.0 # 🙃🙃
|
||||
|
||||
|
||||
|
@ -27,7 +30,7 @@ _TOLERANCE = 3.0 # 🙃🙃
|
|||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
|
||||
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
||||
absolute_painting,shardedArrayAPI):
|
||||
absolute_painting, shardedArrayAPI):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
|
@ -42,18 +45,16 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
if shardedArrayAPI:
|
||||
particles = ShardedArray(particles)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
ic,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order)
|
||||
dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order)
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
|
||||
y0 = jax.tree.map(
|
||||
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
|
||||
dx, p)
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False))
|
||||
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
|
@ -87,13 +88,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
halo_size = mesh_shape[0] // 2
|
||||
|
||||
ic = lax.with_sharding_constraint(initial_conditions,
|
||||
sharding)
|
||||
ic = lax.with_sharding_constraint(initial_conditions, sharding)
|
||||
|
||||
print(f"sharded initial conditions {ic.sharding}")
|
||||
|
||||
if shardedArrayAPI:
|
||||
ic = ShardedArray(ic , sharding)
|
||||
ic = ShardedArray(ic, sharding)
|
||||
|
||||
cosmo._workspace = {}
|
||||
if absolute_painting:
|
||||
|
@ -110,12 +110,13 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
|
||||
y0 = jax.tree.map(
|
||||
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
|
||||
dx, p)
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
ic,
|
||||
|
@ -124,12 +125,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(
|
||||
mesh_shape,
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
|
||||
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
|
@ -161,7 +161,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
sharding=sharding)
|
||||
|
||||
multi_device_final_field_g = process_allgather(multi_device_final_field,
|
||||
tiled=True)
|
||||
tiled=True)
|
||||
|
||||
single_device_final_field_arr, = jax.tree.leaves(single_device_final_field)
|
||||
multi_device_final_field_arr, = jax.tree.leaves(multi_device_final_field_g)
|
||||
|
@ -170,18 +170,19 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
|
|||
|
||||
if shardedArrayAPI:
|
||||
assert type(multi_device_final_field) == ShardedArray
|
||||
assert compare_sharding(multi_device_final_field.sharding , sharding)
|
||||
assert compare_sharding(multi_device_final_field.initial_sharding , sharding)
|
||||
assert compare_sharding(multi_device_final_field.sharding, sharding)
|
||||
assert compare_sharding(multi_device_final_field.initial_sharding,
|
||||
sharding)
|
||||
|
||||
assert mse < _TOLERANCE
|
||||
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("order", [1, 2])
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2,
|
||||
absolute_painting):
|
||||
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
|
||||
order, nbody_from_lpt1, nbody_from_lpt2,
|
||||
absolute_painting):
|
||||
|
||||
mesh_shape, box_shape = simulation_config
|
||||
# SINGLE DEVICE RUN
|
||||
|
@ -196,55 +197,53 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
|||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
|
||||
|
||||
initial_conditions = ShardedArray(initial_conditions , sharding)
|
||||
initial_conditions = ShardedArray(initial_conditions, sharding)
|
||||
|
||||
cosmo._workspace = {}
|
||||
|
||||
@jax.jit
|
||||
def forward_model(initial_conditions , cosmo):
|
||||
|
||||
def forward_model(initial_conditions, cosmo):
|
||||
|
||||
if absolute_painting:
|
||||
particles = uniform_particles(mesh_shape, sharding=sharding)
|
||||
particles = ShardedArray(particles, sharding)
|
||||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
initial_conditions,
|
||||
particles,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(
|
||||
mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
make_diffrax_ode(mesh_shape,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
|
||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
|
||||
y0 = jax.tree.map(
|
||||
lambda particles, dx, p: jnp.stack([particles + dx, p]),
|
||||
particles, dx, p)
|
||||
else:
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
initial_conditions,
|
||||
a=0.1,
|
||||
order=order,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(
|
||||
mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=False,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
atol=1e-8,
|
||||
pcoeff=0.4,
|
||||
icoeff=1,
|
||||
dcoeff=0)
|
||||
|
||||
saveat = SaveAt(t1=True)
|
||||
|
||||
|
@ -260,9 +259,9 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
|||
|
||||
if absolute_painting:
|
||||
multi_device_final_field = cic_paint(jnp.zeros(shape=mesh_shape),
|
||||
solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
else:
|
||||
multi_device_final_field = cic_paint_dx(solutions.ys[-1, 0],
|
||||
halo_size=halo_size,
|
||||
|
@ -271,30 +270,31 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
|
|||
return multi_device_final_field
|
||||
|
||||
@jax.jit
|
||||
def model(initial_conditions , cosmo):
|
||||
def model(initial_conditions, cosmo):
|
||||
|
||||
final_field = forward_model(initial_conditions , cosmo)
|
||||
final_field = forward_model(initial_conditions, cosmo)
|
||||
final_field, = jax.tree.leaves(final_field)
|
||||
|
||||
|
||||
return MSE(final_field,
|
||||
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
|
||||
|
||||
obs_val = model(initial_conditions , cosmo)
|
||||
|
||||
shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5
|
||||
obs_val = model(initial_conditions, cosmo)
|
||||
|
||||
good_grads = jax.grad(model)(initial_conditions , cosmo)
|
||||
off_grads = jax.grad(model)(shifted_initial_conditions , cosmo)
|
||||
shifted_initial_conditions = initial_conditions + jax.random.normal(
|
||||
jax.random.key(42), initial_conditions.shape) * 5
|
||||
|
||||
assert compare_sharding(good_grads.sharding , initial_conditions.sharding)
|
||||
assert compare_sharding(off_grads.sharding , initial_conditions.sharding)
|
||||
good_grads = jax.grad(model)(initial_conditions, cosmo)
|
||||
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
|
||||
|
||||
assert compare_sharding(good_grads.sharding, initial_conditions.sharding)
|
||||
assert compare_sharding(off_grads.sharding, initial_conditions.sharding)
|
||||
|
||||
|
||||
@pytest.mark.distributed
|
||||
@pytest.mark.parametrize("absolute_painting", [True, False])
|
||||
def test_fwd_rev_gradients(cosmo,absolute_painting):
|
||||
def test_fwd_rev_gradients(cosmo, absolute_painting):
|
||||
|
||||
mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0)
|
||||
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
|
||||
# SINGLE DEVICE RUN
|
||||
cosmo._workspace = {}
|
||||
|
||||
|
@ -308,34 +308,54 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
|
|||
sharding)
|
||||
|
||||
print(f"sharded initial conditions {initial_conditions.sharding}")
|
||||
initial_conditions = ShardedArray(initial_conditions , sharding)
|
||||
initial_conditions = ShardedArray(initial_conditions, sharding)
|
||||
|
||||
cosmo._workspace = {}
|
||||
|
||||
@partial(jax.jit , static_argnums=(3,4 , 5))
|
||||
def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None):
|
||||
|
||||
@partial(jax.jit, static_argnums=(3, 4, 5))
|
||||
def compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
particles=None,
|
||||
a=0.5,
|
||||
halo_size=0,
|
||||
sharding=None):
|
||||
|
||||
paint_absolute_pos = particles is not None
|
||||
if particles is None:
|
||||
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
|
||||
shape=(*ic.shape, 3)) , initial_conditions)
|
||||
particles = jax.tree.map(
|
||||
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
|
||||
initial_conditions)
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
delta_k = fft3d(initial_conditions)
|
||||
initial_force = pm_forces(particles,
|
||||
delta=delta_k,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
delta=delta_k,
|
||||
paint_absolute_pos=paint_absolute_pos,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
return initial_force[...,0]
|
||||
return initial_force[..., 0]
|
||||
|
||||
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None
|
||||
forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
|
||||
back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
|
||||
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
|
||||
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding),
|
||||
sharding) if absolute_painting else None
|
||||
forces = compute_forces(initial_conditions,
|
||||
cosmo,
|
||||
particles=particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
|
||||
cosmo,
|
||||
particles=particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
|
||||
cosmo,
|
||||
particles=particles,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
assert compare_sharding(forces.sharding , initial_conditions.sharding)
|
||||
assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding)
|
||||
assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding)
|
||||
assert compare_sharding(forces.sharding, initial_conditions.sharding)
|
||||
assert compare_sharding(back_gradient[0, 0, 0, ...].sharding,
|
||||
initial_conditions.sharding)
|
||||
assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding)
|
||||
|
|
|
@ -1,31 +1,31 @@
|
|||
import os
|
||||
|
||||
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||
|
||||
|
||||
import os
|
||||
os.environ["EQX_ON_ERROR"] = "nan"
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
|
||||
diffeqsolve)
|
||||
from jax.debug import visualize_array_sharding
|
||||
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
|
||||
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||
from functools import partial
|
||||
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
|
||||
from jaxpm.distributed import uniform_particles
|
||||
|
||||
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
||||
|
||||
|
||||
|
||||
from jax.experimental.mesh_utils import create_device_mesh
|
||||
from jax.experimental.multihost_utils import process_allgather
|
||||
from jax.sharding import Mesh, NamedSharding
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.kernels import interpolate_power_spectrum
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
|
||||
|
||||
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
|
||||
|
||||
|
||||
all_gather = partial(process_allgather, tiled=False)
|
||||
|
||||
pdims = (2, 4)
|
||||
|
@ -34,8 +34,8 @@ pdims = (2, 4)
|
|||
#sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
sharding = None
|
||||
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from jaxdecomp import ShardedArray
|
||||
|
||||
mesh_shape = 64
|
||||
|
@ -43,42 +43,42 @@ box_size = 64.
|
|||
halo_size = 2
|
||||
snapshots = (0.5, 1.0)
|
||||
|
||||
|
||||
class Params(NamedTuple):
|
||||
omega_c: float
|
||||
sigma8: float
|
||||
initial_conditions : jnp.ndarray
|
||||
initial_conditions: jnp.ndarray
|
||||
|
||||
mesh_shape = (mesh_shape,) * 3
|
||||
box_size = (box_size,) * 3
|
||||
|
||||
mesh_shape = (mesh_shape, ) * 3
|
||||
box_size = (box_size, ) * 3
|
||||
omega_c = 0.25
|
||||
sigma8 = 0.8
|
||||
# Create a small function to generate the matter power spectrum
|
||||
k = jnp.logspace(-4, 1, 128)
|
||||
pk = jc.power.linear_matter_power(
|
||||
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
|
||||
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8),
|
||||
k)
|
||||
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
|
||||
|
||||
initial_conditions = linear_field(mesh_shape,
|
||||
box_size,
|
||||
pk_fn,
|
||||
seed=jax.random.PRNGKey(0),
|
||||
sharding=sharding)
|
||||
|
||||
box_size,
|
||||
pk_fn,
|
||||
seed=jax.random.PRNGKey(0),
|
||||
sharding=sharding)
|
||||
|
||||
#initial_conditions = ShardedArray(initial_conditions, sharding)
|
||||
|
||||
params = Params(omega_c, sigma8, initial_conditions)
|
||||
|
||||
|
||||
|
||||
@partial(jax.jit , static_argnums=(1 , 2,3,4 ))
|
||||
def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
|
||||
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
|
||||
def forward_model(params, mesh_shape, box_size, halo_size, snapshots):
|
||||
|
||||
# Create initial conditions
|
||||
cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
|
||||
particles = uniform_particles(mesh_shape , sharding)
|
||||
particles = uniform_particles(mesh_shape, sharding)
|
||||
ic_structure = jax.tree.structure(params.initial_conditions)
|
||||
particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles))
|
||||
particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles))
|
||||
# Initial displacement
|
||||
dx, p, f = lpt(cosmo,
|
||||
params.initial_conditions,
|
||||
|
@ -90,10 +90,15 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
|
|||
|
||||
# Evolve the simulation forward
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding))
|
||||
make_diffrax_ode(mesh_shape,
|
||||
paint_absolute_pos=True,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding))
|
||||
solver = LeapfrogMidpoint()
|
||||
|
||||
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p)
|
||||
y0 = jax.tree.map(
|
||||
lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0),
|
||||
particles, dx, p)
|
||||
print(f"y0 structure: {jax.tree.structure(y0)}")
|
||||
|
||||
stepsize_controller = ConstantStepSize()
|
||||
|
@ -107,18 +112,17 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
|
|||
saveat=SaveAt(ts=snapshots),
|
||||
stepsize_controller=stepsize_controller)
|
||||
ode_solutions = [sol[0] for sol in res.ys]
|
||||
|
||||
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1])
|
||||
return particles + dx , ode_field
|
||||
|
||||
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32),
|
||||
ode_solutions[-1])
|
||||
return particles + dx, ode_field
|
||||
|
||||
ode_field = cic_paint_dx(ode_solutions[-1])
|
||||
return dx , ode_field
|
||||
return dx, ode_field
|
||||
|
||||
|
||||
|
||||
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots)
|
||||
|
||||
lpt_particles, ode_field = forward_model(params, mesh_shape, box_size,
|
||||
halo_size, snapshots)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
@ -127,11 +131,11 @@ lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles)
|
|||
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.subplot(121)
|
||||
plt.imshow(lpt_field.sum(axis=0) , cmap='magma')
|
||||
plt.imshow(lpt_field.sum(axis=0), cmap='magma')
|
||||
plt.colorbar()
|
||||
plt.title('LPT field')
|
||||
plt.subplot(122)
|
||||
plt.imshow(ode_field.sum(axis=0) , cmap='magma')
|
||||
plt.imshow(ode_field.sum(axis=0), cmap='magma')
|
||||
plt.colorbar()
|
||||
plt.title('ODE field')
|
||||
plt.show()
|
||||
|
@ -144,4 +148,4 @@ plt.close()
|
|||
#field = ShardedArray(field, sharding)
|
||||
#
|
||||
#
|
||||
#cic_read_dx(field , particles )
|
||||
#cic_read_dx(field , particles )
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue