update code

This commit is contained in:
Wassim Kabalan 2024-12-06 18:56:24 +01:00
parent e0c118a540
commit 21373b89ee
7 changed files with 84 additions and 100 deletions

View file

@ -82,7 +82,7 @@ def slice_unpad_impl(x, pad_width):
def slice_pad(x, pad_width, sharding):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty) and (
if gpu_mesh is not None and not (gpu_mesh.empty) and (
pad_width[0][0] > 0 or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
@ -96,7 +96,7 @@ def slice_pad(x, pad_width, sharding):
def slice_unpad(x, pad_width, sharding):
mesh = sharding.mesh if sharding is not None else None
if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
or pad_width[1][0] > 0):
assert sharding is not None
spec = sharding.spec
@ -122,20 +122,6 @@ def get_local_shape(mesh_shape, sharding=None):
]
def zeros(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
return shard_map(
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
mesh=gpu_mesh,
in_specs=(),
out_specs=spec)() # yapf: disable
else:
return jnp.zeros(mesh_shape)
def __axis_names(spec):
if len(spec) == 1:
x_axis, = spec
@ -158,7 +144,7 @@ def __axis_names(spec):
def uniform_particles(mesh_shape, sharding=None):
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
spec = sharding.spec
x_axis, y_axis, single_axis = __axis_names(spec)
@ -183,7 +169,7 @@ def uniform_particles(mesh_shape, sharding=None):
def normal_field(mesh_shape, seed, sharding=None):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if not gpu_mesh is None and not (gpu_mesh.empty):
if gpu_mesh is not None and not (gpu_mesh.empty):
local_mesh_shape = get_local_shape(mesh_shape, sharding)
size = jax.device_count()

View file

@ -1,5 +1,4 @@
import jax.numpy as jnp
import jax_cosmo as jc
import numpy as np
from jax.lib.xla_client import FftType
from jax.sharding import PartitionSpec as P

View file

@ -204,7 +204,7 @@ def cic_paint_dx(displacements,
return grid_mesh
def cic_read_dx_impl(grid_mesh, halo_size):
def cic_read_dx_impl(grid_mesh, disp, halo_size):
halo_x, _ = halo_size[0]
halo_y, _ = halo_size[1]
@ -220,14 +220,15 @@ def cic_read_dx_impl(grid_mesh, halo_size):
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])
return gather(pmid, jnp.zeros_like(pmid),
return gather(pmid, disp,
grid_mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(1, 2))
def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
# return mesh
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
@ -238,7 +239,7 @@ def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh)
out_specs=spec)(grid_mesh , disp)
return displacements

View file

@ -25,72 +25,71 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
return remainder, chunks
def enmesh(i1, d1, a1, s1, b12, a2, s2):
def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
"""Multilinear enmeshing."""
i1 = jnp.asarray(i1)
d1 = jnp.asarray(d1)
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1,
dtype=d1.dtype)
if s1 is not None:
s1 = jnp.array(s1, dtype=i1.dtype)
b12 = jnp.float64(b12)
if a2 is not None:
a2 = jnp.float64(a2)
if s2 is not None:
s2 = jnp.array(s2, dtype=i1.dtype)
cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
dim = i1.shape[1]
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
jnp.arange(dim, dtype=i1.dtype)) & 1
spatial_dim = base_indices.shape[1]
neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
if a2 is not None:
P = i1 * a1 + d1 - b12
P = P[:, jnp.newaxis] # insert neighbor axis
i2 = P + neighbors * a2 # multilinear
if new_cell_size is not None:
particle_positions = base_indices * cell_size + displacements - offset
particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
if s1 is not None:
L = s1 * a1
i2 %= L
if base_shape is not None:
grid_length = base_shape * cell_size
new_indices %= grid_length
i2 //= a2
d2 = P - i2 * a2
new_indices //= new_cell_size
new_displacements = particle_positions - new_indices * new_cell_size
if s1 is not None:
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
if base_shape is not None:
new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected
i2 = i2.astype(i1.dtype)
d2 = d2.astype(d1.dtype)
a2 = a2.astype(d1.dtype)
new_indices = new_indices.astype(base_indices.dtype)
new_displacements = new_displacements.astype(displacements.dtype)
new_cell_size = new_cell_size.astype(displacements.dtype)
d2 /= a2
new_displacements /= new_cell_size
else:
i12, d12 = jnp.divmod(b12, a1)
i1 -= i12.astype(i1.dtype)
d1 -= d12.astype(d1.dtype)
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
base_indices -= offset_indices.astype(base_indices.dtype)
displacements -= offset_displacements.astype(displacements.dtype)
# insert neighbor axis
i1 = i1[:, jnp.newaxis]
d1 = d1[:, jnp.newaxis]
base_indices = base_indices[:, jnp.newaxis]
displacements = displacements[:, jnp.newaxis]
# multilinear
d1 /= a1
i2 = jnp.floor(d1).astype(i1.dtype)
i2 += neighbors
d2 = d1 - i2
i2 += i1
displacements /= cell_size
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
new_indices += neighbor_offsets
new_displacements = displacements - new_indices
new_indices += base_indices
if s1 is not None:
i2 %= s1
if base_shape is not None:
new_indices %= base_shape
f2 = 1 - jnp.abs(d2)
weights = 1 - jnp.abs(new_displacements)
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
i2 = jnp.where(i2 < 0, s2, i2)
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
f2 = f2.prod(axis=-1)
weights = weights.prod(axis=-1)
return i2, f2
return new_indices, weights
def _scatter_chunk(carry, chunk):
@ -138,7 +137,7 @@ def _chunk_cat(remainder_array, chunked_array):
return array
def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
ptcl_num, spatial_ndim = pmid.shape
mesh = jnp.asarray(mesh)

View file

@ -1,4 +1,3 @@
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

View file

@ -1,11 +1,9 @@
from functools import partial
import jax.numpy as jnp
import jax_cosmo as jc
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
normal_field, zeros)
from jaxpm.distributed import (fft3d, ifft3d,
normal_field)
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
@ -29,17 +27,17 @@ def pm_forces(positions,
mesh_shape = delta.shape
if paint_absolute_pos:
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
x,
halo_size=halo_size,
sharding=sharding)
read_fn = lambda x: cic_read(
x, positions, halo_size=halo_size, sharding=sharding)
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
pos,
halo_size=halo_size,
sharding=sharding)
read_fn = lambda grid_mesh, pos: cic_read(
grid_mesh, pos, halo_size=halo_size, sharding=sharding)
else:
paint_fn = partial(cic_paint_dx,
halo_size=halo_size,
sharding=sharding)
read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding)
paint_fn = lambda disp: cic_paint_dx(
disp, halo_size=halo_size, sharding=sharding)
read_fn = lambda grid_mesh, disp: cic_read_dx(
grid_mesh, disp, halo_size=halo_size, sharding=sharding)
if delta is None:
field = paint_fn(positions)
@ -55,7 +53,7 @@ def pm_forces(positions,
kvec, r_split=r_split)
# Computes gravitational forces
forces = jnp.stack([
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
) for i in range(3)], axis=-1) # yapf: disable
return forces
@ -73,6 +71,8 @@ def lpt(cosmo,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -167,25 +167,27 @@ def make_ode_fn(mesh_shape,
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
#dpos = dpos if not paint_absolute_pos else dpos + pos
return dpos, dvel
return nbody_ode
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
def make_diffrax_ode(cosmo, mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(a, state, args):
"""
State is an array [position, velocities]
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(
pos, mesh_shape, halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m
forces = pm_forces(pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
@ -197,7 +199,6 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
return nbody_ode
def pgd_correction(pos, mesh_shape, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,

View file

@ -5,7 +5,6 @@ import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
from jaxpm.growth import growth_factor, growth_rate
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',