This commit is contained in:
Wassim Kabalan 2025-01-20 22:41:19 +01:00
parent 20fe25c324
commit 1f5c619531
10 changed files with 290 additions and 210 deletions

View file

@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width):
return x[tuple(unpad_slice)] return x[tuple(unpad_slice)]
def slice_pad_impl(x, pad_width): def slice_pad_impl(x, pad_width):
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x) return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)

View file

@ -1,9 +1,10 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.lax import FftType from jax.lax import FftType
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs from jaxdecomp import fftfreq3d, get_output_specs
import jax
from jaxpm.distributed import autoshmap from jaxpm.distributed import autoshmap
@ -25,7 +26,8 @@ def fftk(k_array):
def interpolate_power_spectrum(input, k, pk, sharding=None): def interpolate_power_spectrum(input, k, pk, sharding=None):
def pk_fn(input): def pk_fn(input):
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input) return jax.tree.map(
lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P() specs = sharding.spec if sharding is not None else P()
@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1):
return wts return wts
else: else:
w = kvec[direction] w = kvec[direction]
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w) a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)),
w)
wts = a * 1j wts = a * 1j
return wts return wts
@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values Complex kernel values
""" """
if fd: if fd:
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec) kk = sum(
jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki)
for ki in kvec)
else: else:
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec) kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk) kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk) return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros,
kk)
def longrange_kernel(kvec, r_split): def longrange_kernel(kvec, r_split):
@ -131,7 +137,10 @@ def cic_compensation(kvec):
wts: array wts: array
Complex kernel values Complex kernel values
""" """
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)] kwts = [
jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i])
for i in range(3)
]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2) wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts return wts

View file

@ -29,18 +29,26 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None: if weight is not None:
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)): if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1) kernel = jax.tree.map(
, k) , kernel , weight) lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
kernel, weight)
else: else:
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight) kernel = jax.tree.map(
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
kernel, weight)
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords) neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(
nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)
), neighboor_coords)
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2), inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, scatter_dims_to_operand_dims=(0, 1,
2)) 2))
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel) mesh = jax.tree.map(
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums),
grid_mesh, neighboor_coords, kernel)
return mesh return mesh
@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions_structure = jax.tree.structure(positions) positions_structure = jax.tree.structure(positions)
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh)) grid_mesh = jax.tree.unflatten(positions_structure,
jax.tree.leaves(grid_mesh))
positions = positions.reshape((*grid_mesh.shape, 3)) positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding) halo_size, halo_extents = get_halo_size(halo_size, sharding)
@ -91,12 +100,15 @@ def _cic_read_impl(grid_mesh, positions):
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords) kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary # Modulo operation to wrap around edges if necessary
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32') neighboor_coords = jax.tree.map(
,jnp.array(grid_mesh.shape)) , neighboor_coords) lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)),
neighboor_coords)
# Ensure grid_mesh shape is as expected # Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel) grid_mesh = jax.tree.map(
lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh,
neighboor_coords, kernel)
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
@ -157,7 +169,9 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
original_shape = displacements.shape original_shape = displacements.shape
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements) particle_mesh = jax.tree.map(
lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype),
displacements)
if not jnp.isscalar(weight): if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]: if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape") raise ValueError("Weight shape must match particle shape")
@ -165,13 +179,18 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
weight = weight.flatten() weight = weight.flatten()
# Padding is forced to be zero in a single gpu run # Padding is forced to be zero in a single gpu run
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]), a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
jnp.arange(x.shape[1]), jnp.arange(x.shape[1]),
jnp.arange(x.shape[2]), jnp.arange(x.shape[2]),
indexing='ij') , axis=0), particle_mesh) indexing='ij'),
axis=0), particle_mesh)
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh) particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size),
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) particle_mesh)
pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
return scatter(pmid.reshape([-1, 3]), return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]), displacements.reshape([-1, 3]),
particle_mesh, particle_mesh,
@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
jnp.arange(original_shape[1]), jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]), jnp.arange(original_shape[2]),
indexing='ij') indexing='ij')
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]), a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]), jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]), jnp.arange(original_shape[2]),
indexing='ij') , axis=0), grid_mesh) indexing='ij'),
axis=0), grid_mesh)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3]) disp = disp.reshape([-1, 3])

View file

@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_displacements = particle_positions - new_indices * new_cell_size new_displacements = particle_positions - new_indices * new_cell_size
if base_shape is not None: if base_shape is not None:
new_displacements -= jax.tree.map(jnp.rint , new_displacements -= jax.tree.map(
new_displacements / grid_length jnp.rint, new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected ) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_indices = new_indices.astype(base_indices.dtype) new_indices = new_indices.astype(base_indices.dtype)
@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk):
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape) spatial_shape)
# scatter # scatter
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
mesh_structure = jax.tree.structure(mesh) mesh_structure = jax.tree.structure(mesh)
val_flat = jax.tree.leaves(val) val_flat = jax.tree.leaves(val)
val_tree = jax.tree.unflatten(mesh_structure, val_flat) val_tree = jax.tree.unflatten(mesh_structure, val_flat)
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac) mesh = jax.tree.map(
lambda m, v, i, f: m.at[i].add(
jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind,
frac)
carry = mesh, offset, cell_size, mesh_shape carry = mesh, offset, cell_size, mesh_shape
return carry, None return carry, None
@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk):
spatial_shape) spatial_shape)
# gather # gather
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac) frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
ind_structure = jax.tree.structure(ind) ind_structure = jax.tree.structure(ind)
frac_structure = jax.tree.structure(frac) frac_structure = jax.tree.structure(frac)
mesh_structure = jax.tree.structure(mesh) mesh_structure = jax.tree.structure(mesh)
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac) val += jax.tree.map(
lambda m, i, f:
(m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind,
frac)
return carry, val return carry, val

View file

@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
@ -7,7 +8,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel) invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
import jax
def pm_forces(positions, def pm_forces(positions,
mesh_shape=None, mesh_shape=None,
@ -52,10 +53,12 @@ def pm_forces(positions,
kvec, r_split=r_split) kvec, r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
forces = [ forces = [
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
) for i in range(3)] for i in range(3)
]
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2]) forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1),
forces[0], forces[1], forces[2])
return forces return forces
@ -73,8 +76,9 @@ def lpt(cosmo,
""" """
paint_absolute_pos = particles is not None paint_absolute_pos = particles is not None
if particles is None: if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, particles = jax.tree.map(
shape=(*ic.shape, 3)) , initial_conditions) lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a)) E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -198,7 +202,8 @@ def make_diffrax_ode(mesh_shape,
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel) return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos,
dvel)
return nbody_ode return nbody_ode

View file

@ -1,16 +1,17 @@
from functools import partial from functools import partial
import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
from scipy.special import legendre from scipy.special import legendre
import jax
__all__ = [ __all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh', 'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing' 'cross_correlation_coefficients', 'gaussian_smoothing'
] ]
def _initialize_pk(mesh_shape, box_shape, kedges, los): def _initialize_pk(mesh_shape, box_shape, kedges, los):
""" """
Parameters Parameters
@ -104,7 +105,8 @@ def power_spectrum(mesh,
if mesh2 is None: if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2 mmk = meshk.real**2 + meshk.imag**2
else: else:
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2) mmk = meshk * jax.tree.map(
lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2)
# Sum powers # Sum powers
pk = jnp.empty((len(poles), n_bins)) pk = jnp.empty((len(poles), n_bins))

View file

@ -174,15 +174,19 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
return fpm_mesh return fpm_mesh
def compare_sharding(sharding1, sharding2): def compare_sharding(sharding1, sharding2):
def get_axis_size(sharding, idx): def get_axis_size(sharding, idx):
axis_name = sharding.spec[idx] axis_name = sharding.spec[idx]
if axis_name is None: if axis_name is None:
return 1 return 1
else: else:
return sharding.mesh.shape[sharding.spec[idx]] return sharding.mesh.shape[sharding.spec[idx]]
def get_pdims_from_sharding(sharding): def get_pdims_from_sharding(sharding):
return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))]) return tuple(
[get_axis_size(sharding, i) for i in range(len(sharding.spec))])
pdims1 = get_pdims_from_sharding(sharding1) pdims1 = get_pdims_from_sharding(sharding1)
pdims2 = get_pdims_from_sharding(sharding2) pdims2 = get_pdims_from_sharding(sharding2)

View file

@ -1,14 +1,15 @@
import jax
import pytest import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from jaxdecomp import ShardedArray from jaxdecomp import ShardedArray
from jaxpm.distributed import uniform_particles from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum from jaxpm.utils import power_spectrum
import jax
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
@ -17,7 +18,8 @@ _PM_TOLERANCE = 1e-3
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -53,7 +55,8 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -77,6 +80,7 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
assert type(dx) == ShardedArray assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray assert type(lpt_field) == ShardedArray
@pytest.mark.single_device @pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
@ -110,7 +114,8 @@ def test_nbody_absolute(simulation_config, initial_conditions,
saveat = SaveAt(t1=True) saveat = SaveAt(t1=True)
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p) y0 = jax.tree.map(lambda particles, dx, p: jnp.stack([particles + dx, p]),
particles, dx, p)
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
@ -155,8 +160,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-9, controller = PIDController(rtol=1e-9,

View file

@ -1,9 +1,12 @@
from conftest import initialize_distributed , compare_sharding from conftest import compare_sharding, initialize_distributed
initialize_distributed() # ignore : E402 initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402 import jax # noqa : E402
import jax.numpy as jnp # noqa : E402 import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402 import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402 from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,13 +15,13 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402 from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxdecomp import ShardedArray # noqa : E402 from jaxdecomp import ShardedArray # noqa : E402
from functools import partial # noqa : E402
import jax_cosmo as jc # noqa : E402 from jaxpm.distributed import fft3d, uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 _TOLERANCE = 3.0 # 🙃🙃
@ -42,17 +45,15 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if shardedArrayAPI: if shardedArrayAPI:
particles = ShardedArray(particles) particles = ShardedArray(particles)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order)
ic,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
dx, p)
else: else:
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order) dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) paint_absolute_pos=False))
y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p) y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5() solver = Dopri5()
@ -87,8 +88,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
ic = lax.with_sharding_constraint(initial_conditions, ic = lax.with_sharding_constraint(initial_conditions, sharding)
sharding)
print(f"sharded initial conditions {ic.sharding}") print(f"sharded initial conditions {ic.sharding}")
@ -110,12 +110,13 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
dx, p)
else: else:
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
ic, ic,
@ -124,8 +125,7 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -171,16 +171,17 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if shardedArrayAPI: if shardedArrayAPI:
assert type(multi_device_final_field) == ShardedArray assert type(multi_device_final_field) == ShardedArray
assert compare_sharding(multi_device_final_field.sharding, sharding) assert compare_sharding(multi_device_final_field.sharding, sharding)
assert compare_sharding(multi_device_final_field.initial_sharding , sharding) assert compare_sharding(multi_device_final_field.initial_sharding,
sharding)
assert mse < _TOLERANCE assert mse < _TOLERANCE
@pytest.mark.distributed @pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2, def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2,
absolute_painting): absolute_painting):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
@ -196,7 +197,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
print(f"sharded initial conditions {initial_conditions.sharding}") print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions, sharding) initial_conditions = ShardedArray(initial_conditions, sharding)
cosmo._workspace = {} cosmo._workspace = {}
@ -204,7 +204,6 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
@jax.jit @jax.jit
def forward_model(initial_conditions, cosmo): def forward_model(initial_conditions, cosmo):
if absolute_painting: if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding) particles = uniform_particles(mesh_shape, sharding=sharding)
particles = ShardedArray(particles, sharding) particles = ShardedArray(particles, sharding)
@ -218,12 +217,13 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]),
particles, dx, p)
else: else:
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, initial_conditions,
@ -232,8 +232,7 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
@ -281,7 +280,8 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
obs_val = model(initial_conditions, cosmo) obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5 shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions, cosmo) good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions, cosmo) off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
@ -313,12 +313,18 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
cosmo._workspace = {} cosmo._workspace = {}
@partial(jax.jit, static_argnums=(3, 4, 5)) @partial(jax.jit, static_argnums=(3, 4, 5))
def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None): def compute_forces(initial_conditions,
cosmo,
particles=None,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = particles is not None paint_absolute_pos = particles is not None
if particles is None: if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, particles = jax.tree.map(
shape=(*ic.shape, 3)) , initial_conditions) lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a)) E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -331,11 +337,25 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
return initial_force[..., 0] return initial_force[..., 0]
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding),
forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) sharding) if absolute_painting else None
back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) forces = compute_forces(initial_conditions,
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
assert compare_sharding(forces.sharding, initial_conditions.sharding) assert compare_sharding(forces.sharding, initial_conditions.sharding)
assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding) assert compare_sharding(back_gradient[0, 0, 0, ...].sharding,
initial_conditions.sharding)
assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding) assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding)

View file

@ -1,31 +1,31 @@
import os import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu" #os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" #os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import os
os.environ["EQX_ON_ERROR"] = "nan" os.environ["EQX_ON_ERROR"] = "nan"
from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
diffeqsolve)
from jax.debug import visualize_array_sharding from jax.debug import visualize_array_sharding
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
from jaxpm.distributed import uniform_particles
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
from jax.experimental.mesh_utils import create_device_mesh from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.distributed import uniform_particles
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
all_gather = partial(process_allgather, tiled=False) all_gather = partial(process_allgather, tiled=False)
pdims = (2, 4) pdims = (2, 4)
@ -34,8 +34,8 @@ pdims = (2, 4)
#sharding = NamedSharding(mesh, P('x', 'y')) #sharding = NamedSharding(mesh, P('x', 'y'))
sharding = None sharding = None
from typing import NamedTuple from typing import NamedTuple
from jaxdecomp import ShardedArray from jaxdecomp import ShardedArray
mesh_shape = 64 mesh_shape = 64
@ -43,19 +43,21 @@ box_size = 64.
halo_size = 2 halo_size = 2
snapshots = (0.5, 1.0) snapshots = (0.5, 1.0)
class Params(NamedTuple): class Params(NamedTuple):
omega_c: float omega_c: float
sigma8: float sigma8: float
initial_conditions: jnp.ndarray initial_conditions: jnp.ndarray
mesh_shape = (mesh_shape, ) * 3 mesh_shape = (mesh_shape, ) * 3
box_size = (box_size, ) * 3 box_size = (box_size, ) * 3
omega_c = 0.25 omega_c = 0.25
sigma8 = 0.8 sigma8 = 0.8
# Create a small function to generate the matter power spectrum # Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128) k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power( pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8),
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape, initial_conditions = linear_field(mesh_shape,
@ -64,13 +66,11 @@ initial_conditions = linear_field(mesh_shape,
seed=jax.random.PRNGKey(0), seed=jax.random.PRNGKey(0),
sharding=sharding) sharding=sharding)
#initial_conditions = ShardedArray(initial_conditions, sharding) #initial_conditions = ShardedArray(initial_conditions, sharding)
params = Params(omega_c, sigma8, initial_conditions) params = Params(omega_c, sigma8, initial_conditions)
@partial(jax.jit, static_argnums=(1, 2, 3, 4)) @partial(jax.jit, static_argnums=(1, 2, 3, 4))
def forward_model(params, mesh_shape, box_size, halo_size, snapshots): def forward_model(params, mesh_shape, box_size, halo_size, snapshots):
@ -90,10 +90,15 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
# Evolve the simulation forward # Evolve the simulation forward
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding)) make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=halo_size,
sharding=sharding))
solver = LeapfrogMidpoint() solver = LeapfrogMidpoint()
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0),
particles, dx, p)
print(f"y0 structure: {jax.tree.structure(y0)}") print(f"y0 structure: {jax.tree.structure(y0)}")
stepsize_controller = ConstantStepSize() stepsize_controller = ConstantStepSize()
@ -108,17 +113,16 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
stepsize_controller=stepsize_controller) stepsize_controller=stepsize_controller)
ode_solutions = [sol[0] for sol in res.ys] ode_solutions = [sol[0] for sol in res.ys]
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1]) ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32),
ode_solutions[-1])
return particles + dx, ode_field return particles + dx, ode_field
ode_field = cic_paint_dx(ode_solutions[-1]) ode_field = cic_paint_dx(ode_solutions[-1])
return dx, ode_field return dx, ode_field
lpt_particles, ode_field = forward_model(params, mesh_shape, box_size,
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots) halo_size, snapshots)
import matplotlib.pyplot as plt import matplotlib.pyplot as plt