This commit is contained in:
Wassim Kabalan 2025-01-20 22:41:19 +01:00
parent 20fe25c324
commit 1f5c619531
10 changed files with 290 additions and 210 deletions

View file

@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width):
return x[tuple(unpad_slice)]
def slice_pad_impl(x, pad_width):
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)

View file

@ -1,9 +1,10 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
import jax
from jaxpm.distributed import autoshmap
@ -25,7 +26,8 @@ def fftk(k_array):
def interpolate_power_spectrum(input, k, pk, sharding=None):
def pk_fn(input):
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
return jax.tree.map(
lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1):
return wts
else:
w = kvec[direction]
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w)
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)),
w)
wts = a * 1j
return wts
@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values
"""
if fd:
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec)
kk = sum(
jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki)
for ki in kvec)
else:
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk)
return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros,
kk)
def longrange_kernel(kvec, r_split):
@ -131,7 +137,10 @@ def cic_compensation(kvec):
wts: array
Complex kernel values
"""
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)]
kwts = [
jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i])
for i in range(3)
]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts

View file

@ -19,28 +19,36 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
"""
positions = positions.reshape([-1, 3])
positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions)
floor = jax.tree.map(jnp.floor , positions)
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
floor = jax.tree.map(jnp.floor, positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords))
kernel = 1. - jax.tree.map(jnp.abs, (positions - neighboor_coords))
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1)
, k) , kernel , weight)
kernel = jax.tree.map(
lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
kernel, weight)
else:
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight)
kernel = jax.tree.map(
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
kernel, weight)
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords)
neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(
nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)
), neighboor_coords)
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel)
mesh = jax.tree.map(
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums),
grid_mesh, neighboor_coords, kernel)
return mesh
@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions_structure = jax.tree.structure(positions)
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
grid_mesh = jax.tree.unflatten(positions_structure,
jax.tree.leaves(grid_mesh))
positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding)
@ -79,24 +88,27 @@ def _cic_read_impl(grid_mesh, positions):
# Reshape positions to a flat list of 3D coordinates
positions = positions.reshape([-1, 3])
# Expand dimensions to calculate neighbor coordinates
positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions)
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
# Floor the positions to get the base grid cell for each particle
floor = jax.tree.map(jnp.floor , positions)
floor = jax.tree.map(jnp.floor, positions)
# Define connections to calculate all neighbor coordinates
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
# Calculate the 8 neighboring coordinates
neighboor_coords = floor + connection
# Calculate kernel weights based on distance from each neighboring coordinate
kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords)
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32')
,jnp.array(grid_mesh.shape)) , neighboor_coords)
neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)),
neighboor_coords)
# Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel)
grid_mesh = jax.tree.map(
lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh,
neighboor_coords, kernel)
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
@ -157,7 +169,9 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_y, _ = halo_size[1]
original_shape = displacements.shape
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements)
particle_mesh = jax.tree.map(
lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype),
displacements)
if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape")
@ -165,13 +179,18 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
weight = weight.flatten()
# Padding is forced to be zero in a single gpu run
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
jnp.arange(x.shape[1]),
jnp.arange(x.shape[2]),
indexing='ij') , axis=0), particle_mesh)
indexing='ij'),
axis=0), particle_mesh)
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size),
particle_mesh)
pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]),
particle_mesh,
@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij')
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]),
indexing='ij') , axis=0), grid_mesh)
indexing='ij'),
axis=0), grid_mesh)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])

View file

@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_cell_size, new_shape):
"""Multilinear enmeshing."""
base_indices = jax.tree.map(jnp.asarray , base_indices)
displacements = jax.tree.map(jnp.asarray , displacements)
base_indices = jax.tree.map(jnp.asarray, base_indices)
displacements = jax.tree.map(jnp.asarray, displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_displacements = particle_positions - new_indices * new_cell_size
if base_shape is not None:
new_displacements -= jax.tree.map(jnp.rint ,
new_displacements / grid_length
new_displacements -= jax.tree.map(
jnp.rint, new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_indices = new_indices.astype(base_indices.dtype)
@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
if base_shape is not None:
new_indices %= base_shape
weights = 1 - jax.tree.map(jnp.abs , new_displacements)
weights = 1 - jax.tree.map(jnp.abs, new_displacements)
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk):
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape)
# scatter
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
mesh_structure = jax.tree.structure(mesh)
val_flat = jax.tree.leaves(val)
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
mesh = jax.tree.map(
lambda m, v, i, f: m.at[i].add(
jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind,
frac)
carry = mesh, offset, cell_size, mesh_shape
return carry, None
@ -127,8 +131,8 @@ def scatter(pmid,
cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
val = jax.tree.map(jnp.asarray , val)
mesh = jax.tree.map(jnp.asarray , mesh)
val = jax.tree.map(jnp.asarray, val)
mesh = jax.tree.map(jnp.asarray, mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape
if remainder is not None:
@ -151,9 +155,9 @@ def _chunk_cat(remainder_array, chunked_array):
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
mesh = jax.tree.map(jnp.asarray , mesh)
mesh = jax.tree.map(jnp.asarray, mesh)
val = jax.tree.map(jnp.asarray , val)
val = jax.tree.map(jnp.asarray, val)
if mesh.shape[spatial_ndim:] != val.shape[1:]:
raise ValueError('channel shape mismatch: '
@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk):
spatial_shape)
# gather
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
ind_structure = jax.tree.structure(ind)
frac_structure = jax.tree.structure(frac)
mesh_structure = jax.tree.structure(mesh)
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
val += jax.tree.map(
lambda m, i, f:
(m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind,
frac)
return carry, val

View file

@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
@ -7,7 +8,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
import jax
def pm_forces(positions,
mesh_shape=None,
@ -52,10 +53,12 @@ def pm_forces(positions,
kvec, r_split=r_split)
# Computes gravitational forces
forces = [
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
) for i in range(3)]
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
for i in range(3)
]
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1),
forces[0], forces[1], forces[2])
return forces
@ -73,8 +76,9 @@ def lpt(cosmo,
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
shape=(*ic.shape, 3)) , initial_conditions)
particles = jax.tree.map(
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -198,7 +202,8 @@ def make_diffrax_ode(mesh_shape,
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos,
dvel)
return nbody_ode

View file

@ -1,16 +1,17 @@
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
import jax
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]
def _initialize_pk(mesh_shape, box_shape, kedges, los):
"""
Parameters
@ -100,11 +101,12 @@ def power_spectrum(mesh,
n_bins = len(kavg) + 2
# FFTs
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
meshk = jax.tree.map(lambda x: jnp.fft.fftn(x, norm='ortho'), mesh)
if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2
else:
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
mmk = meshk * jax.tree.map(
lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2)
# Sum powers
pk = jnp.empty((len(poles), n_bins))

View file

@ -174,18 +174,22 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
return fpm_mesh
def compare_sharding(sharding1, sharding2):
def get_axis_size(sharding, idx):
axis_name = sharding.spec[idx]
if axis_name is None:
return 1
else:
return sharding.mesh.shape[sharding.spec[idx]]
def get_pdims_from_sharding(sharding):
return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))])
return tuple(
[get_axis_size(sharding, i) for i in range(len(sharding.spec))])
pdims1 = get_pdims_from_sharding(sharding1)
pdims2 = get_pdims_from_sharding(sharding2)
pdims1 = pdims1 + (1,) * (3 - len(pdims1))
pdims2 = pdims2 + (1,) * (3 - len(pdims2))
pdims1 = pdims1 + (1, ) * (3 - len(pdims1))
pdims2 = pdims2 + (1, ) * (3 - len(pdims2))
return pdims1 == pdims2

View file

@ -1,14 +1,15 @@
import jax
import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE
from jax import numpy as jnp
from jaxdecomp import ShardedArray
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum
import jax
_TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3
@ -17,7 +18,8 @@ _PM_TOLERANCE = 1e-3
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
@ -53,7 +55,8 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI):
fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
@ -77,12 +80,13 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray
@pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_absolute(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order , shardedArrayAPI):
cosmo, order, shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
@ -110,7 +114,8 @@ def test_nbody_absolute(simulation_config, initial_conditions,
saveat = SaveAt(t1=True)
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p)
y0 = jax.tree.map(lambda particles, dx, p: jnp.stack([particles + dx, p]),
particles, dx, p)
solutions = diffeqsolve(ode_fn,
solver,
@ -135,7 +140,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray
assert type(solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray
@ -144,7 +149,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_relative(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order , shardedArrayAPI):
cosmo, order, shardedArrayAPI):
mesh_shape, box_shape = simulation_config
cosmo._workspace = {}
@ -155,8 +160,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5()
controller = PIDController(rtol=1e-9,
@ -167,7 +171,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
saveat = SaveAt(t1=True)
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p)
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solutions = diffeqsolve(ode_fn,
solver,
@ -192,5 +196,5 @@ def test_nbody_relative(simulation_config, initial_conditions,
if shardedArrayAPI:
assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray
assert type(solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray

View file

@ -1,9 +1,12 @@
from conftest import initialize_distributed , compare_sharding
from conftest import compare_sharding, initialize_distributed
initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402
import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,13 +15,13 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxdecomp import ShardedArray # noqa : E402
from functools import partial # noqa : E402
import jax_cosmo as jc # noqa : E402
from jaxpm.distributed import fft3d, uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃
@ -27,7 +30,7 @@ _TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting,shardedArrayAPI):
absolute_painting, shardedArrayAPI):
mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN
@ -42,18 +45,16 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if shardedArrayAPI:
particles = ShardedArray(particles)
# Initial displacement
dx, p, _ = lpt(cosmo,
ic,
particles,
a=0.1,
order=order)
dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
dx, p)
else:
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
paint_absolute_pos=False))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -87,13 +88,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2
ic = lax.with_sharding_constraint(initial_conditions,
sharding)
ic = lax.with_sharding_constraint(initial_conditions, sharding)
print(f"sharded initial conditions {ic.sharding}")
if shardedArrayAPI:
ic = ShardedArray(ic , sharding)
ic = ShardedArray(ic, sharding)
cosmo._workspace = {}
if absolute_painting:
@ -110,12 +110,13 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(
mesh_shape,
make_diffrax_ode(mesh_shape,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
dx, p)
else:
dx, p, _ = lpt(cosmo,
ic,
@ -124,12 +125,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(
mesh_shape,
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -170,17 +170,18 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if shardedArrayAPI:
assert type(multi_device_final_field) == ShardedArray
assert compare_sharding(multi_device_final_field.sharding , sharding)
assert compare_sharding(multi_device_final_field.initial_sharding , sharding)
assert compare_sharding(multi_device_final_field.sharding, sharding)
assert compare_sharding(multi_device_final_field.initial_sharding,
sharding)
assert mse < _TOLERANCE
@pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2,
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2,
absolute_painting):
mesh_shape, box_shape = simulation_config
@ -196,14 +197,12 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions , sharding)
initial_conditions = ShardedArray(initial_conditions, sharding)
cosmo._workspace = {}
@jax.jit
def forward_model(initial_conditions , cosmo):
def forward_model(initial_conditions, cosmo):
if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding)
@ -218,12 +217,13 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(
mesh_shape,
make_diffrax_ode(mesh_shape,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p)
y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]),
particles, dx, p)
else:
dx, p, _ = lpt(cosmo,
initial_conditions,
@ -232,12 +232,11 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
halo_size=halo_size,
sharding=sharding)
ode_fn = ODETerm(
make_diffrax_ode(
mesh_shape,
make_diffrax_ode(mesh_shape,
paint_absolute_pos=False,
halo_size=halo_size,
sharding=sharding))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p)
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5()
controller = PIDController(rtol=1e-8,
@ -271,30 +270,31 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
return multi_device_final_field
@jax.jit
def model(initial_conditions , cosmo):
def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions , cosmo)
final_field = forward_model(initial_conditions, cosmo)
final_field, = jax.tree.leaves(final_field)
return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions , cosmo)
obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5
shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions , cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions , cosmo)
good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert compare_sharding(good_grads.sharding , initial_conditions.sharding)
assert compare_sharding(off_grads.sharding , initial_conditions.sharding)
assert compare_sharding(good_grads.sharding, initial_conditions.sharding)
assert compare_sharding(off_grads.sharding, initial_conditions.sharding)
@pytest.mark.distributed
@pytest.mark.parametrize("absolute_painting", [True, False])
def test_fwd_rev_gradients(cosmo,absolute_painting):
def test_fwd_rev_gradients(cosmo, absolute_painting):
mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0)
mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
# SINGLE DEVICE RUN
cosmo._workspace = {}
@ -308,17 +308,23 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
sharding)
print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions , sharding)
initial_conditions = ShardedArray(initial_conditions, sharding)
cosmo._workspace = {}
@partial(jax.jit , static_argnums=(3,4 , 5))
def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None):
@partial(jax.jit, static_argnums=(3, 4, 5))
def compute_forces(initial_conditions,
cosmo,
particles=None,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = particles is not None
if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
shape=(*ic.shape, 3)) , initial_conditions)
particles = jax.tree.map(
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -329,13 +335,27 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
halo_size=halo_size,
sharding=sharding)
return initial_force[...,0]
return initial_force[..., 0]
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None
forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding)
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding),
sharding) if absolute_painting else None
forces = compute_forces(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
assert compare_sharding(forces.sharding , initial_conditions.sharding)
assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding)
assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding)
assert compare_sharding(forces.sharding, initial_conditions.sharding)
assert compare_sharding(back_gradient[0, 0, 0, ...].sharding,
initial_conditions.sharding)
assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding)

View file

@ -1,31 +1,31 @@
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import os
os.environ["EQX_ON_ERROR"] = "nan"
from functools import partial
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
diffeqsolve)
from jax.debug import visualize_array_sharding
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
from jaxpm.distributed import uniform_particles
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import uniform_particles
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
all_gather = partial(process_allgather, tiled=False)
pdims = (2, 4)
@ -34,8 +34,8 @@ pdims = (2, 4)
#sharding = NamedSharding(mesh, P('x', 'y'))
sharding = None
from typing import NamedTuple
from jaxdecomp import ShardedArray
mesh_shape = 64
@ -43,19 +43,21 @@ box_size = 64.
halo_size = 2
snapshots = (0.5, 1.0)
class Params(NamedTuple):
omega_c: float
sigma8: float
initial_conditions : jnp.ndarray
initial_conditions: jnp.ndarray
mesh_shape = (mesh_shape,) * 3
box_size = (box_size,) * 3
mesh_shape = (mesh_shape, ) * 3
box_size = (box_size, ) * 3
omega_c = 0.25
sigma8 = 0.8
# Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8),
k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape,
@ -64,21 +66,19 @@ initial_conditions = linear_field(mesh_shape,
seed=jax.random.PRNGKey(0),
sharding=sharding)
#initial_conditions = ShardedArray(initial_conditions, sharding)
params = Params(omega_c, sigma8, initial_conditions)
@partial(jax.jit , static_argnums=(1 , 2,3,4 ))
def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def forward_model(params, mesh_shape, box_size, halo_size, snapshots):
# Create initial conditions
cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
particles = uniform_particles(mesh_shape , sharding)
particles = uniform_particles(mesh_shape, sharding)
ic_structure = jax.tree.structure(params.initial_conditions)
particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles))
particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles))
# Initial displacement
dx, p, f = lpt(cosmo,
params.initial_conditions,
@ -90,10 +90,15 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
# Evolve the simulation forward
ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding))
make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=halo_size,
sharding=sharding))
solver = LeapfrogMidpoint()
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p)
y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0),
particles, dx, p)
print(f"y0 structure: {jax.tree.structure(y0)}")
stepsize_controller = ConstantStepSize()
@ -108,17 +113,16 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
stepsize_controller=stepsize_controller)
ode_solutions = [sol[0] for sol in res.ys]
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1])
return particles + dx , ode_field
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32),
ode_solutions[-1])
return particles + dx, ode_field
ode_field = cic_paint_dx(ode_solutions[-1])
return dx , ode_field
return dx, ode_field
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots)
lpt_particles, ode_field = forward_model(params, mesh_shape, box_size,
halo_size, snapshots)
import matplotlib.pyplot as plt
@ -127,11 +131,11 @@ lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles)
plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.imshow(lpt_field.sum(axis=0) , cmap='magma')
plt.imshow(lpt_field.sum(axis=0), cmap='magma')
plt.colorbar()
plt.title('LPT field')
plt.subplot(122)
plt.imshow(ode_field.sum(axis=0) , cmap='magma')
plt.imshow(ode_field.sum(axis=0), cmap='magma')
plt.colorbar()
plt.title('ODE field')
plt.show()