This commit is contained in:
Wassim Kabalan 2025-01-20 22:41:19 +01:00
parent 20fe25c324
commit 1f5c619531
10 changed files with 290 additions and 210 deletions

View file

@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width):
return x[tuple(unpad_slice)] return x[tuple(unpad_slice)]
def slice_pad_impl(x, pad_width): def slice_pad_impl(x, pad_width):
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x) return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)

View file

@ -1,9 +1,10 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.lax import FftType from jax.lax import FftType
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs from jaxdecomp import fftfreq3d, get_output_specs
import jax
from jaxpm.distributed import autoshmap from jaxpm.distributed import autoshmap
@ -25,7 +26,8 @@ def fftk(k_array):
def interpolate_power_spectrum(input, k, pk, sharding=None): def interpolate_power_spectrum(input, k, pk, sharding=None):
def pk_fn(input): def pk_fn(input):
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input) return jax.tree.map(
lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
gpu_mesh = sharding.mesh if sharding is not None else None gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P() specs = sharding.spec if sharding is not None else P()
@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1):
return wts return wts
else: else:
w = kvec[direction] w = kvec[direction]
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w) a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)),
w)
wts = a * 1j wts = a * 1j
return wts return wts
@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False):
Complex kernel values Complex kernel values
""" """
if fd: if fd:
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec) kk = sum(
jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki)
for ki in kvec)
else: else:
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec) kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk) kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk) return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros,
kk)
def longrange_kernel(kvec, r_split): def longrange_kernel(kvec, r_split):
@ -131,7 +137,10 @@ def cic_compensation(kvec):
wts: array wts: array
Complex kernel values Complex kernel values
""" """
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)] kwts = [
jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i])
for i in range(3)
]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2) wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts return wts

View file

@ -19,28 +19,36 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
""" """
positions = positions.reshape([-1, 3]) positions = positions.reshape([-1, 3])
positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions) positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
floor = jax.tree.map(jnp.floor , positions) floor = jax.tree.map(jnp.floor, positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords)) kernel = 1. - jax.tree.map(jnp.abs, (positions - neighboor_coords))
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None: if weight is not None:
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)): if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1) kernel = jax.tree.map(
, k) , kernel , weight) lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
kernel, weight)
else: else:
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight) kernel = jax.tree.map(
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
kernel, weight)
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords) neighboor_coords = jax.tree.map(
lambda nc: jnp.mod(
nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)
), neighboor_coords)
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(), dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2), inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, scatter_dims_to_operand_dims=(0, 1,
2)) 2))
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel) mesh = jax.tree.map(
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums),
grid_mesh, neighboor_coords, kernel)
return mesh return mesh
@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None): def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
positions_structure = jax.tree.structure(positions) positions_structure = jax.tree.structure(positions)
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh)) grid_mesh = jax.tree.unflatten(positions_structure,
jax.tree.leaves(grid_mesh))
positions = positions.reshape((*grid_mesh.shape, 3)) positions = positions.reshape((*grid_mesh.shape, 3))
halo_size, halo_extents = get_halo_size(halo_size, sharding) halo_size, halo_extents = get_halo_size(halo_size, sharding)
@ -79,24 +88,27 @@ def _cic_read_impl(grid_mesh, positions):
# Reshape positions to a flat list of 3D coordinates # Reshape positions to a flat list of 3D coordinates
positions = positions.reshape([-1, 3]) positions = positions.reshape([-1, 3])
# Expand dimensions to calculate neighbor coordinates # Expand dimensions to calculate neighbor coordinates
positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions) positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
# Floor the positions to get the base grid cell for each particle # Floor the positions to get the base grid cell for each particle
floor = jax.tree.map(jnp.floor , positions) floor = jax.tree.map(jnp.floor, positions)
# Define connections to calculate all neighbor coordinates # Define connections to calculate all neighbor coordinates
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]]) [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
# Calculate the 8 neighboring coordinates # Calculate the 8 neighboring coordinates
neighboor_coords = floor + connection neighboor_coords = floor + connection
# Calculate kernel weights based on distance from each neighboring coordinate # Calculate kernel weights based on distance from each neighboring coordinate
kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords) kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary # Modulo operation to wrap around edges if necessary
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32') neighboor_coords = jax.tree.map(
,jnp.array(grid_mesh.shape)) , neighboor_coords) lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)),
neighboor_coords)
# Ensure grid_mesh shape is as expected # Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel) grid_mesh = jax.tree.map(
lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh,
neighboor_coords, kernel)
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
@ -157,7 +169,9 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
halo_y, _ = halo_size[1] halo_y, _ = halo_size[1]
original_shape = displacements.shape original_shape = displacements.shape
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements) particle_mesh = jax.tree.map(
lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype),
displacements)
if not jnp.isscalar(weight): if not jnp.isscalar(weight):
if weight.shape != original_shape[:-1]: if weight.shape != original_shape[:-1]:
raise ValueError("Weight shape must match particle shape") raise ValueError("Weight shape must match particle shape")
@ -165,13 +179,18 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
weight = weight.flatten() weight = weight.flatten()
# Padding is forced to be zero in a single gpu run # Padding is forced to be zero in a single gpu run
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]), a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
jnp.arange(x.shape[1]), jnp.arange(x.shape[1]),
jnp.arange(x.shape[2]), jnp.arange(x.shape[2]),
indexing='ij') , axis=0), particle_mesh) indexing='ij'),
axis=0), particle_mesh)
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh) particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size),
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) particle_mesh)
pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
return scatter(pmid.reshape([-1, 3]), return scatter(pmid.reshape([-1, 3]),
displacements.reshape([-1, 3]), displacements.reshape([-1, 3]),
particle_mesh, particle_mesh,
@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
jnp.arange(original_shape[1]), jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]), jnp.arange(original_shape[2]),
indexing='ij') indexing='ij')
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]), a, b, c = jax.tree.map(
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
jnp.arange(original_shape[1]), jnp.arange(original_shape[1]),
jnp.arange(original_shape[2]), jnp.arange(original_shape[2]),
indexing='ij') , axis=0), grid_mesh) indexing='ij'),
axis=0), grid_mesh)
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c) pmid = jax.tree.map(
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
c)
pmid = pmid.reshape([-1, 3]) pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3]) disp = disp.reshape([-1, 3])

View file

@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
def enmesh(base_indices, displacements, cell_size, base_shape, offset, def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_cell_size, new_shape): new_cell_size, new_shape):
"""Multilinear enmeshing.""" """Multilinear enmeshing."""
base_indices = jax.tree.map(jnp.asarray , base_indices) base_indices = jax.tree.map(jnp.asarray, base_indices)
displacements = jax.tree.map(jnp.asarray , displacements) displacements = jax.tree.map(jnp.asarray, displacements)
with jax.experimental.enable_x64(): with jax.experimental.enable_x64():
cell_size = jnp.float64( cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array( cell_size) if new_cell_size is not None else jnp.array(
@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_displacements = particle_positions - new_indices * new_cell_size new_displacements = particle_positions - new_indices * new_cell_size
if base_shape is not None: if base_shape is not None:
new_displacements -= jax.tree.map(jnp.rint , new_displacements -= jax.tree.map(
new_displacements / grid_length jnp.rint, new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected ) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_indices = new_indices.astype(base_indices.dtype) new_indices = new_indices.astype(base_indices.dtype)
@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
if base_shape is not None: if base_shape is not None:
new_indices %= base_shape new_indices %= base_shape
weights = 1 - jax.tree.map(jnp.abs , new_displacements) weights = 1 - jax.tree.map(jnp.abs, new_displacements)
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices) new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk):
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size, ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
spatial_shape) spatial_shape)
# scatter # scatter
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
mesh_structure = jax.tree.structure(mesh) mesh_structure = jax.tree.structure(mesh)
val_flat = jax.tree.leaves(val) val_flat = jax.tree.leaves(val)
val_tree = jax.tree.unflatten(mesh_structure, val_flat) val_tree = jax.tree.unflatten(mesh_structure, val_flat)
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac) mesh = jax.tree.map(
lambda m, v, i, f: m.at[i].add(
jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind,
frac)
carry = mesh, offset, cell_size, mesh_shape carry = mesh, offset, cell_size, mesh_shape
return carry, None return carry, None
@ -127,8 +131,8 @@ def scatter(pmid,
cell_size=1.): cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape ptcl_num, spatial_ndim = pmid.shape
val = jax.tree.map(jnp.asarray , val) val = jax.tree.map(jnp.asarray, val)
mesh = jax.tree.map(jnp.asarray , mesh) mesh = jax.tree.map(jnp.asarray, mesh)
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val) remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
carry = mesh, offset, cell_size, mesh.shape carry = mesh, offset, cell_size, mesh.shape
if remainder is not None: if remainder is not None:
@ -151,9 +155,9 @@ def _chunk_cat(remainder_array, chunked_array):
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.): def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape ptcl_num, spatial_ndim = pmid.shape
mesh = jax.tree.map(jnp.asarray , mesh) mesh = jax.tree.map(jnp.asarray, mesh)
val = jax.tree.map(jnp.asarray , val) val = jax.tree.map(jnp.asarray, val)
if mesh.shape[spatial_ndim:] != val.shape[1:]: if mesh.shape[spatial_ndim:] != val.shape[1:]:
raise ValueError('channel shape mismatch: ' raise ValueError('channel shape mismatch: '
@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk):
spatial_shape) spatial_shape)
# gather # gather
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind) ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
ind)
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac) frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
ind_structure = jax.tree.structure(ind) ind_structure = jax.tree.structure(ind)
frac_structure = jax.tree.structure(frac) frac_structure = jax.tree.structure(frac)
mesh_structure = jax.tree.structure(mesh) mesh_structure = jax.tree.structure(mesh)
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac) val += jax.tree.map(
lambda m, i, f:
(m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind,
frac)
return carry, val return carry, val

View file

@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
@ -7,7 +8,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel) invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
import jax
def pm_forces(positions, def pm_forces(positions,
mesh_shape=None, mesh_shape=None,
@ -52,10 +53,12 @@ def pm_forces(positions,
kvec, r_split=r_split) kvec, r_split=r_split)
# Computes gravitational forces # Computes gravitational forces
forces = [ forces = [
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
) for i in range(3)] for i in range(3)
]
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2]) forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1),
forces[0], forces[1], forces[2])
return forces return forces
@ -73,8 +76,9 @@ def lpt(cosmo,
""" """
paint_absolute_pos = particles is not None paint_absolute_pos = particles is not None
if particles is None: if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, particles = jax.tree.map(
shape=(*ic.shape, 3)) , initial_conditions) lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a)) E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -198,7 +202,8 @@ def make_diffrax_ode(mesh_shape,
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel) return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos,
dvel)
return nbody_ode return nbody_ode

View file

@ -1,16 +1,17 @@
from functools import partial from functools import partial
import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np import numpy as np
from jax.scipy.stats import norm from jax.scipy.stats import norm
from scipy.special import legendre from scipy.special import legendre
import jax
__all__ = [ __all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh', 'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing' 'cross_correlation_coefficients', 'gaussian_smoothing'
] ]
def _initialize_pk(mesh_shape, box_shape, kedges, los): def _initialize_pk(mesh_shape, box_shape, kedges, los):
""" """
Parameters Parameters
@ -100,11 +101,12 @@ def power_spectrum(mesh,
n_bins = len(kavg) + 2 n_bins = len(kavg) + 2
# FFTs # FFTs
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh) meshk = jax.tree.map(lambda x: jnp.fft.fftn(x, norm='ortho'), mesh)
if mesh2 is None: if mesh2 is None:
mmk = meshk.real**2 + meshk.imag**2 mmk = meshk.real**2 + meshk.imag**2
else: else:
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2) mmk = meshk * jax.tree.map(
lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2)
# Sum powers # Sum powers
pk = jnp.empty((len(poles), n_bins)) pk = jnp.empty((len(poles), n_bins))

View file

@ -174,18 +174,22 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
return fpm_mesh return fpm_mesh
def compare_sharding(sharding1, sharding2): def compare_sharding(sharding1, sharding2):
def get_axis_size(sharding, idx): def get_axis_size(sharding, idx):
axis_name = sharding.spec[idx] axis_name = sharding.spec[idx]
if axis_name is None: if axis_name is None:
return 1 return 1
else: else:
return sharding.mesh.shape[sharding.spec[idx]] return sharding.mesh.shape[sharding.spec[idx]]
def get_pdims_from_sharding(sharding): def get_pdims_from_sharding(sharding):
return tuple([get_axis_size(sharding, i) for i in range(len(sharding.spec))]) return tuple(
[get_axis_size(sharding, i) for i in range(len(sharding.spec))])
pdims1 = get_pdims_from_sharding(sharding1) pdims1 = get_pdims_from_sharding(sharding1)
pdims2 = get_pdims_from_sharding(sharding2) pdims2 = get_pdims_from_sharding(sharding2)
pdims1 = pdims1 + (1,) * (3 - len(pdims1)) pdims1 = pdims1 + (1, ) * (3 - len(pdims1))
pdims2 = pdims2 + (1,) * (3 - len(pdims2)) pdims2 = pdims2 + (1, ) * (3 - len(pdims2))
return pdims1 == pdims2 return pdims1 == pdims2

View file

@ -1,14 +1,15 @@
import jax
import pytest import pytest
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
from helpers import MSE, MSRE from helpers import MSE, MSRE
from jax import numpy as jnp from jax import numpy as jnp
from jaxdecomp import ShardedArray from jaxdecomp import ShardedArray
from jaxpm.distributed import uniform_particles from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx from jaxpm.painting import cic_paint, cic_paint_dx
from jaxpm.pm import lpt, make_diffrax_ode from jaxpm.pm import lpt, make_diffrax_ode
from jaxpm.utils import power_spectrum from jaxpm.utils import power_spectrum
import jax
_TOLERANCE = 1e-4 _TOLERANCE = 1e-4
_PM_TOLERANCE = 1e-3 _PM_TOLERANCE = 1e-3
@ -17,7 +18,8 @@ _PM_TOLERANCE = 1e-3
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -53,7 +55,8 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor, def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
fpm_lpt1_field, fpm_lpt2_field, cosmo, order , shardedArrayAPI): fpm_lpt1_field, fpm_lpt2_field, cosmo, order,
shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -77,12 +80,13 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
assert type(dx) == ShardedArray assert type(dx) == ShardedArray
assert type(lpt_field) == ShardedArray assert type(lpt_field) == ShardedArray
@pytest.mark.single_device @pytest.mark.single_device
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_absolute(simulation_config, initial_conditions, def test_nbody_absolute(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order , shardedArrayAPI): cosmo, order, shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -110,7 +114,8 @@ def test_nbody_absolute(simulation_config, initial_conditions,
saveat = SaveAt(t1=True) saveat = SaveAt(t1=True)
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]), particles , dx, p) y0 = jax.tree.map(lambda particles, dx, p: jnp.stack([particles + dx, p]),
particles, dx, p)
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
@ -135,7 +140,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
if shardedArrayAPI: if shardedArrayAPI:
assert type(dx) == ShardedArray assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray assert type(solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray assert type(final_field) == ShardedArray
@ -144,7 +149,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_nbody_relative(simulation_config, initial_conditions, def test_nbody_relative(simulation_config, initial_conditions,
lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2, lpt_scale_factor, nbody_from_lpt1, nbody_from_lpt2,
cosmo, order , shardedArrayAPI): cosmo, order, shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
cosmo._workspace = {} cosmo._workspace = {}
@ -155,8 +160,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order) dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-9, controller = PIDController(rtol=1e-9,
@ -167,7 +171,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
saveat = SaveAt(t1=True) saveat = SaveAt(t1=True)
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]), dx, p) y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solutions = diffeqsolve(ode_fn, solutions = diffeqsolve(ode_fn,
solver, solver,
@ -192,5 +196,5 @@ def test_nbody_relative(simulation_config, initial_conditions,
if shardedArrayAPI: if shardedArrayAPI:
assert type(dx) == ShardedArray assert type(dx) == ShardedArray
assert type( solutions.ys[-1, 0]) == ShardedArray assert type(solutions.ys[-1, 0]) == ShardedArray
assert type(final_field) == ShardedArray assert type(final_field) == ShardedArray

View file

@ -1,9 +1,12 @@
from conftest import initialize_distributed , compare_sharding from conftest import compare_sharding, initialize_distributed
initialize_distributed() # ignore : E402 initialize_distributed() # ignore : E402
from functools import partial # noqa : E402
import jax # noqa : E402 import jax # noqa : E402
import jax.numpy as jnp # noqa : E402 import jax.numpy as jnp # noqa : E402
import jax_cosmo as jc # noqa : E402
import pytest # noqa : E402 import pytest # noqa : E402
from diffrax import SaveAt # noqa : E402 from diffrax import SaveAt # noqa : E402
from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve from diffrax import Dopri5, ODETerm, PIDController, diffeqsolve
@ -12,13 +15,13 @@ from jax import lax # noqa : E402
from jax.experimental.multihost_utils import process_allgather # noqa : E402 from jax.experimental.multihost_utils import process_allgather # noqa : E402
from jax.sharding import NamedSharding from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P # noqa : E402 from jax.sharding import PartitionSpec as P # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.distributed import uniform_particles , fft3d # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
from jaxdecomp import ShardedArray # noqa : E402 from jaxdecomp import ShardedArray # noqa : E402
from functools import partial # noqa : E402
import jax_cosmo as jc # noqa : E402 from jaxpm.distributed import fft3d, uniform_particles # noqa : E402
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import pm_forces # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode # noqa : E402
_TOLERANCE = 3.0 # 🙃🙃 _TOLERANCE = 3.0 # 🙃🙃
@ -27,7 +30,7 @@ _TOLERANCE = 3.0 # 🙃🙃
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
@pytest.mark.parametrize("shardedArrayAPI", [True, False]) @pytest.mark.parametrize("shardedArrayAPI", [True, False])
def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order, def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
absolute_painting,shardedArrayAPI): absolute_painting, shardedArrayAPI):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
# SINGLE DEVICE RUN # SINGLE DEVICE RUN
@ -42,18 +45,16 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if shardedArrayAPI: if shardedArrayAPI:
particles = ShardedArray(particles) particles = ShardedArray(particles)
# Initial displacement # Initial displacement
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo, ic, particles, a=0.1, order=order)
ic,
particles,
a=0.1,
order=order)
ode_fn = ODETerm(make_diffrax_ode(mesh_shape)) ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
dx, p)
else: else:
dx, p, _ = lpt(cosmo, ic, a=0.1, order=order) dx, p, _ = lpt(cosmo, ic, a=0.1, order=order)
ode_fn = ODETerm( ode_fn = ODETerm(make_diffrax_ode(mesh_shape,
make_diffrax_ode(mesh_shape, paint_absolute_pos=False)) paint_absolute_pos=False))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -87,13 +88,12 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding = NamedSharding(mesh, P('x', 'y')) sharding = NamedSharding(mesh, P('x', 'y'))
halo_size = mesh_shape[0] // 2 halo_size = mesh_shape[0] // 2
ic = lax.with_sharding_constraint(initial_conditions, ic = lax.with_sharding_constraint(initial_conditions, sharding)
sharding)
print(f"sharded initial conditions {ic.sharding}") print(f"sharded initial conditions {ic.sharding}")
if shardedArrayAPI: if shardedArrayAPI:
ic = ShardedArray(ic , sharding) ic = ShardedArray(ic, sharding)
cosmo._workspace = {} cosmo._workspace = {}
if absolute_painting: if absolute_painting:
@ -110,12 +110,13 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]), particles,
dx, p)
else: else:
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
ic, ic,
@ -124,12 +125,11 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -170,17 +170,18 @@ def test_distrubted_pm(simulation_config, initial_conditions, cosmo, order,
if shardedArrayAPI: if shardedArrayAPI:
assert type(multi_device_final_field) == ShardedArray assert type(multi_device_final_field) == ShardedArray
assert compare_sharding(multi_device_final_field.sharding , sharding) assert compare_sharding(multi_device_final_field.sharding, sharding)
assert compare_sharding(multi_device_final_field.initial_sharding , sharding) assert compare_sharding(multi_device_final_field.initial_sharding,
sharding)
assert mse < _TOLERANCE assert mse < _TOLERANCE
@pytest.mark.distributed @pytest.mark.distributed
@pytest.mark.parametrize("order", [1, 2]) @pytest.mark.parametrize("order", [1, 2])
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, order,nbody_from_lpt1, nbody_from_lpt2, def test_distrubted_gradients(simulation_config, initial_conditions, cosmo,
order, nbody_from_lpt1, nbody_from_lpt2,
absolute_painting): absolute_painting):
mesh_shape, box_shape = simulation_config mesh_shape, box_shape = simulation_config
@ -196,14 +197,12 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
print(f"sharded initial conditions {initial_conditions.sharding}") print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions, sharding)
initial_conditions = ShardedArray(initial_conditions , sharding)
cosmo._workspace = {} cosmo._workspace = {}
@jax.jit @jax.jit
def forward_model(initial_conditions , cosmo): def forward_model(initial_conditions, cosmo):
if absolute_painting: if absolute_painting:
particles = uniform_particles(mesh_shape, sharding=sharding) particles = uniform_particles(mesh_shape, sharding=sharding)
@ -218,12 +217,13 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx, p]) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p]),
particles, dx, p)
else: else:
dx, p, _ = lpt(cosmo, dx, p, _ = lpt(cosmo,
initial_conditions, initial_conditions,
@ -232,12 +232,11 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode( make_diffrax_ode(mesh_shape,
mesh_shape,
paint_absolute_pos=False, paint_absolute_pos=False,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding)) sharding=sharding))
y0 = jax.tree.map(lambda dx , p : jnp.stack([dx, p]) , dx , p) y0 = jax.tree.map(lambda dx, p: jnp.stack([dx, p]), dx, p)
solver = Dopri5() solver = Dopri5()
controller = PIDController(rtol=1e-8, controller = PIDController(rtol=1e-8,
@ -271,30 +270,31 @@ def test_distrubted_gradients(simulation_config, initial_conditions, cosmo, orde
return multi_device_final_field return multi_device_final_field
@jax.jit @jax.jit
def model(initial_conditions , cosmo): def model(initial_conditions, cosmo):
final_field = forward_model(initial_conditions , cosmo) final_field = forward_model(initial_conditions, cosmo)
final_field, = jax.tree.leaves(final_field) final_field, = jax.tree.leaves(final_field)
return MSE(final_field, return MSE(final_field,
nbody_from_lpt1 if order == 1 else nbody_from_lpt2) nbody_from_lpt1 if order == 1 else nbody_from_lpt2)
obs_val = model(initial_conditions , cosmo) obs_val = model(initial_conditions, cosmo)
shifted_initial_conditions = initial_conditions + jax.random.normal(jax.random.key(42) , initial_conditions.shape) * 5 shifted_initial_conditions = initial_conditions + jax.random.normal(
jax.random.key(42), initial_conditions.shape) * 5
good_grads = jax.grad(model)(initial_conditions , cosmo) good_grads = jax.grad(model)(initial_conditions, cosmo)
off_grads = jax.grad(model)(shifted_initial_conditions , cosmo) off_grads = jax.grad(model)(shifted_initial_conditions, cosmo)
assert compare_sharding(good_grads.sharding , initial_conditions.sharding) assert compare_sharding(good_grads.sharding, initial_conditions.sharding)
assert compare_sharding(off_grads.sharding , initial_conditions.sharding) assert compare_sharding(off_grads.sharding, initial_conditions.sharding)
@pytest.mark.distributed @pytest.mark.distributed
@pytest.mark.parametrize("absolute_painting", [True, False]) @pytest.mark.parametrize("absolute_painting", [True, False])
def test_fwd_rev_gradients(cosmo,absolute_painting): def test_fwd_rev_gradients(cosmo, absolute_painting):
mesh_shape, box_shape = (8 , 8 , 8) , (20.0 , 20.0 , 20.0) mesh_shape, box_shape = (8, 8, 8), (20.0, 20.0, 20.0)
# SINGLE DEVICE RUN # SINGLE DEVICE RUN
cosmo._workspace = {} cosmo._workspace = {}
@ -308,17 +308,23 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
sharding) sharding)
print(f"sharded initial conditions {initial_conditions.sharding}") print(f"sharded initial conditions {initial_conditions.sharding}")
initial_conditions = ShardedArray(initial_conditions , sharding) initial_conditions = ShardedArray(initial_conditions, sharding)
cosmo._workspace = {} cosmo._workspace = {}
@partial(jax.jit , static_argnums=(3,4 , 5)) @partial(jax.jit, static_argnums=(3, 4, 5))
def compute_forces(initial_conditions , cosmo , particles=None , a=0.5 , halo_size=0 , sharding=None): def compute_forces(initial_conditions,
cosmo,
particles=None,
a=0.5,
halo_size=0,
sharding=None):
paint_absolute_pos = particles is not None paint_absolute_pos = particles is not None
if particles is None: if particles is None:
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic, particles = jax.tree.map(
shape=(*ic.shape, 3)) , initial_conditions) lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
initial_conditions)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a)) E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -329,13 +335,27 @@ def test_fwd_rev_gradients(cosmo,absolute_painting):
halo_size=halo_size, halo_size=halo_size,
sharding=sharding) sharding=sharding)
return initial_force[...,0] return initial_force[..., 0]
particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding) , sharding) if absolute_painting else None particles = ShardedArray(uniform_particles(mesh_shape, sharding=sharding),
forces = compute_forces(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) sharding) if absolute_painting else None
back_gradient = jax.jacrev(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) forces = compute_forces(initial_conditions,
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions , cosmo , particles=particles,halo_size=halo_size , sharding=sharding) cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
back_gradient = jax.jacrev(compute_forces)(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
fwd_gradient = jax.jacfwd(compute_forces)(initial_conditions,
cosmo,
particles=particles,
halo_size=halo_size,
sharding=sharding)
assert compare_sharding(forces.sharding , initial_conditions.sharding) assert compare_sharding(forces.sharding, initial_conditions.sharding)
assert compare_sharding(back_gradient[0,0,0,...].sharding , initial_conditions.sharding) assert compare_sharding(back_gradient[0, 0, 0, ...].sharding,
assert compare_sharding(fwd_gradient.sharding , initial_conditions.sharding) initial_conditions.sharding)
assert compare_sharding(fwd_gradient.sharding, initial_conditions.sharding)

View file

@ -1,31 +1,31 @@
import os import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu" #os.environ["JAX_PLATFORM_NAME"] = "cpu"
#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" #os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import os
os.environ["EQX_ON_ERROR"] = "nan" os.environ["EQX_ON_ERROR"] = "nan"
from functools import partial
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt,
diffeqsolve)
from jax.debug import visualize_array_sharding from jax.debug import visualize_array_sharding
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint_dx , cic_read_dx , cic_paint , cic_read
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
from functools import partial
from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve
from jaxpm.distributed import uniform_particles
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
from jax.experimental.mesh_utils import create_device_mesh from jax.experimental.mesh_utils import create_device_mesh
from jax.experimental.multihost_utils import process_allgather from jax.experimental.multihost_utils import process_allgather
from jax.sharding import Mesh, NamedSharding from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P from jax.sharding import PartitionSpec as P
from jaxpm.distributed import uniform_particles
from jaxpm.kernels import interpolate_power_spectrum
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
from jaxpm.pm import linear_field, lpt, make_diffrax_ode
#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices"
all_gather = partial(process_allgather, tiled=False) all_gather = partial(process_allgather, tiled=False)
pdims = (2, 4) pdims = (2, 4)
@ -34,8 +34,8 @@ pdims = (2, 4)
#sharding = NamedSharding(mesh, P('x', 'y')) #sharding = NamedSharding(mesh, P('x', 'y'))
sharding = None sharding = None
from typing import NamedTuple from typing import NamedTuple
from jaxdecomp import ShardedArray from jaxdecomp import ShardedArray
mesh_shape = 64 mesh_shape = 64
@ -43,19 +43,21 @@ box_size = 64.
halo_size = 2 halo_size = 2
snapshots = (0.5, 1.0) snapshots = (0.5, 1.0)
class Params(NamedTuple): class Params(NamedTuple):
omega_c: float omega_c: float
sigma8: float sigma8: float
initial_conditions : jnp.ndarray initial_conditions: jnp.ndarray
mesh_shape = (mesh_shape,) * 3
box_size = (box_size,) * 3 mesh_shape = (mesh_shape, ) * 3
box_size = (box_size, ) * 3
omega_c = 0.25 omega_c = 0.25
sigma8 = 0.8 sigma8 = 0.8
# Create a small function to generate the matter power spectrum # Create a small function to generate the matter power spectrum
k = jnp.logspace(-4, 1, 128) k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power( pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8),
jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k) k)
pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding)
initial_conditions = linear_field(mesh_shape, initial_conditions = linear_field(mesh_shape,
@ -64,21 +66,19 @@ initial_conditions = linear_field(mesh_shape,
seed=jax.random.PRNGKey(0), seed=jax.random.PRNGKey(0),
sharding=sharding) sharding=sharding)
#initial_conditions = ShardedArray(initial_conditions, sharding) #initial_conditions = ShardedArray(initial_conditions, sharding)
params = Params(omega_c, sigma8, initial_conditions) params = Params(omega_c, sigma8, initial_conditions)
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
@partial(jax.jit , static_argnums=(1 , 2,3,4 )) def forward_model(params, mesh_shape, box_size, halo_size, snapshots):
def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
# Create initial conditions # Create initial conditions
cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8) cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8)
particles = uniform_particles(mesh_shape , sharding) particles = uniform_particles(mesh_shape, sharding)
ic_structure = jax.tree.structure(params.initial_conditions) ic_structure = jax.tree.structure(params.initial_conditions)
particles = jax.tree.unflatten(ic_structure , jax.tree.leaves(particles)) particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles))
# Initial displacement # Initial displacement
dx, p, f = lpt(cosmo, dx, p, f = lpt(cosmo,
params.initial_conditions, params.initial_conditions,
@ -90,10 +90,15 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
# Evolve the simulation forward # Evolve the simulation forward
ode_fn = ODETerm( ode_fn = ODETerm(
make_diffrax_ode(mesh_shape, paint_absolute_pos=True,halo_size=halo_size,sharding=sharding)) make_diffrax_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=halo_size,
sharding=sharding))
solver = LeapfrogMidpoint() solver = LeapfrogMidpoint()
y0 = jax.tree.map(lambda particles , dx , p : jnp.stack([particles + dx ,p],axis=0) , particles , dx , p) y0 = jax.tree.map(
lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0),
particles, dx, p)
print(f"y0 structure: {jax.tree.structure(y0)}") print(f"y0 structure: {jax.tree.structure(y0)}")
stepsize_controller = ConstantStepSize() stepsize_controller = ConstantStepSize()
@ -108,17 +113,16 @@ def forward_model(params , mesh_shape,box_size,halo_size , snapshots):
stepsize_controller=stepsize_controller) stepsize_controller=stepsize_controller)
ode_solutions = [sol[0] for sol in res.ys] ode_solutions = [sol[0] for sol in res.ys]
ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), ode_solutions[-1]) ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32),
return particles + dx , ode_field ode_solutions[-1])
return particles + dx, ode_field
ode_field = cic_paint_dx(ode_solutions[-1]) ode_field = cic_paint_dx(ode_solutions[-1])
return dx , ode_field return dx, ode_field
lpt_particles, ode_field = forward_model(params, mesh_shape, box_size,
lpt_particles , ode_field = forward_model(params , mesh_shape,box_size,halo_size , snapshots) halo_size, snapshots)
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -127,11 +131,11 @@ lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles)
plt.figure(figsize=(12, 6)) plt.figure(figsize=(12, 6))
plt.subplot(121) plt.subplot(121)
plt.imshow(lpt_field.sum(axis=0) , cmap='magma') plt.imshow(lpt_field.sum(axis=0), cmap='magma')
plt.colorbar() plt.colorbar()
plt.title('LPT field') plt.title('LPT field')
plt.subplot(122) plt.subplot(122)
plt.imshow(ode_field.sum(axis=0) , cmap='magma') plt.imshow(ode_field.sum(axis=0), cmap='magma')
plt.colorbar() plt.colorbar()
plt.title('ODE field') plt.title('ODE field')
plt.show() plt.show()