This commit is contained in:
Wassim Kabalan 2024-12-08 22:45:09 +01:00
parent 7823fdaf98
commit af29c4005d
7 changed files with 68 additions and 63 deletions

View file

@ -222,12 +222,11 @@ def cic_read_dx_impl(grid_mesh, disp, halo_size):
pmid = pmid.reshape([-1, 3])
disp = disp.reshape([-1, 3])
return gather(pmid, disp,
grid_mesh).reshape(original_shape)
return gather(pmid, disp, grid_mesh).reshape(original_shape)
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
@ -239,7 +238,7 @@ def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
gpu_mesh=gpu_mesh,
in_specs=(spec),
out_specs=spec)(grid_mesh , disp)
out_specs=spec)(grid_mesh, disp)
return displacements

View file

@ -25,12 +25,15 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
return remainder, chunks
def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
new_cell_size, new_shape):
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
@ -40,12 +43,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
spatial_dim = base_indices.shape[1]
neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
neighbor_offsets = (
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
if new_cell_size is not None:
particle_positions = base_indices * cell_size + displacements - offset
particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis
particle_positions = particle_positions[:, jnp.
newaxis] # insert neighbor axis
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
if base_shape is not None:
@ -56,7 +61,9 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_
new_displacements = particle_positions - new_indices * new_cell_size
if base_shape is not None:
new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_displacements -= jnp.rint(
new_displacements / grid_length
) * grid_length # also abs(new_displacements) < new_cell_size is expected
new_indices = new_indices.astype(base_indices.dtype)
new_displacements = new_displacements.astype(displacements.dtype)

View file

@ -1,9 +1,7 @@
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.distributed import (fft3d, ifft3d,
normal_field)
from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
@ -27,7 +25,8 @@ def pm_forces(positions,
mesh_shape = delta.shape
if paint_absolute_pos:
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
device=sharding),
pos,
halo_size=halo_size,
sharding=sharding)
@ -72,7 +71,8 @@ def lpt(cosmo,
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@ -172,10 +172,11 @@ def make_ode_fn(mesh_shape,
return nbody_ode
def make_diffrax_ode(cosmo, mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(a, state, args):
"""
@ -199,6 +200,7 @@ def make_diffrax_ode(cosmo, mesh_shape,
return nbody_ode
def pgd_correction(pos, mesh_shape, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,

View file

@ -5,7 +5,6 @@ import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'