mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 03:51:11 +00:00
format
This commit is contained in:
parent
20fe25c324
commit
1f5c619531
10 changed files with 290 additions and 210 deletions
|
@ -79,6 +79,7 @@ def slice_unpad_impl(x, pad_width):
|
|||
|
||||
return x[tuple(unpad_slice)]
|
||||
|
||||
|
||||
def slice_pad_impl(x, pad_width):
|
||||
return jax.tree.map(lambda x: jnp.pad(x, pad_width), x)
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.lax import FftType
|
||||
from jax.sharding import PartitionSpec as P
|
||||
from jaxdecomp import fftfreq3d, get_output_specs
|
||||
import jax
|
||||
|
||||
from jaxpm.distributed import autoshmap
|
||||
|
||||
|
||||
|
@ -25,7 +26,8 @@ def fftk(k_array):
|
|||
def interpolate_power_spectrum(input, k, pk, sharding=None):
|
||||
|
||||
def pk_fn(input):
|
||||
return jax.tree.map(lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
|
||||
return jax.tree.map(
|
||||
lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape), input)
|
||||
|
||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||
specs = sharding.spec if sharding is not None else P()
|
||||
|
@ -61,7 +63,8 @@ def gradient_kernel(kvec, direction, order=1):
|
|||
return wts
|
||||
else:
|
||||
w = kvec[direction]
|
||||
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)), w)
|
||||
a = jax.tree.map(lambda x: 1 / 6.0 * (8 * jnp.sin(x) - jnp.sin(2 * x)),
|
||||
w)
|
||||
wts = a * 1j
|
||||
return wts
|
||||
|
||||
|
@ -85,11 +88,14 @@ def invlaplace_kernel(kvec, fd=False):
|
|||
Complex kernel values
|
||||
"""
|
||||
if fd:
|
||||
kk = sum(jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki) for ki in kvec)
|
||||
kk = sum(
|
||||
jax.tree.map(lambda x: (x * jnp.sinc(x / (2 * jnp.pi)))**2, ki)
|
||||
for ki in kvec)
|
||||
else:
|
||||
kk = sum(jax.tree.map(lambda x: x**2, ki) for ki in kvec)
|
||||
kk_nozeros = jax.tree.map(lambda x: jnp.where(x == 0, 1, x), kk)
|
||||
return jax.tree.map(lambda x , y : -jnp.where(y == 0, 0, 1 / x), kk_nozeros, kk)
|
||||
return jax.tree.map(lambda x, y: -jnp.where(y == 0, 0, 1 / x), kk_nozeros,
|
||||
kk)
|
||||
|
||||
|
||||
def longrange_kernel(kvec, r_split):
|
||||
|
@ -131,7 +137,10 @@ def cic_compensation(kvec):
|
|||
wts: array
|
||||
Complex kernel values
|
||||
"""
|
||||
kwts = [jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i]) for i in range(3)]
|
||||
kwts = [
|
||||
jax.tree.map(lambda x: jnp.sinc(x / (2 * np.pi)), kvec[i])
|
||||
for i in range(3)
|
||||
]
|
||||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||
return wts
|
||||
|
||||
|
|
|
@ -19,28 +19,36 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
"""
|
||||
|
||||
positions = positions.reshape([-1, 3])
|
||||
positions = jax.tree.map(lambda p : jnp.expand_dims(p , 1) , positions)
|
||||
floor = jax.tree.map(jnp.floor , positions)
|
||||
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
|
||||
floor = jax.tree.map(jnp.floor, positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jax.tree.map(jnp.abs , (positions - neighboor_coords))
|
||||
kernel = 1. - jax.tree.map(jnp.abs, (positions - neighboor_coords))
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
if weight is not None:
|
||||
if jax.tree.all(jax.tree.map(jnp.isscalar, weight)):
|
||||
kernel = jax.tree.map(lambda k , w : jnp.multiply(jnp.expand_dims(w, axis=-1)
|
||||
, k) , kernel , weight)
|
||||
kernel = jax.tree.map(
|
||||
lambda k, w: jnp.multiply(jnp.expand_dims(w, axis=-1), k),
|
||||
kernel, weight)
|
||||
else:
|
||||
kernel = jax.tree.map(lambda k , w : jnp.multiply(w.reshape(*positions.shape[:-1]) , k) , kernel , weight)
|
||||
kernel = jax.tree.map(
|
||||
lambda k, w: jnp.multiply(w.reshape(*positions.shape[:-1]), k),
|
||||
kernel, weight)
|
||||
|
||||
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)) , neighboor_coords)
|
||||
neighboor_coords = jax.tree.map(
|
||||
lambda nc: jnp.mod(
|
||||
nc.reshape([-1, 8, 3]).astype('int32'), jnp.array(grid_mesh.shape)
|
||||
), neighboor_coords)
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1,
|
||||
2))
|
||||
mesh = jax.tree.map(lambda g , nc , k : lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums) , grid_mesh , neighboor_coords , kernel)
|
||||
mesh = jax.tree.map(
|
||||
lambda g, nc, k: lax.scatter_add(g, nc, k.reshape([-1, 8]), dnums),
|
||||
grid_mesh, neighboor_coords, kernel)
|
||||
|
||||
return mesh
|
||||
|
||||
|
@ -49,7 +57,8 @@ def _cic_paint_impl(grid_mesh, positions, weight=None):
|
|||
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
|
||||
|
||||
positions_structure = jax.tree.structure(positions)
|
||||
grid_mesh = jax.tree.unflatten(positions_structure, jax.tree.leaves(grid_mesh))
|
||||
grid_mesh = jax.tree.unflatten(positions_structure,
|
||||
jax.tree.leaves(grid_mesh))
|
||||
positions = positions.reshape((*grid_mesh.shape, 3))
|
||||
|
||||
halo_size, halo_extents = get_halo_size(halo_size, sharding)
|
||||
|
@ -79,24 +88,27 @@ def _cic_read_impl(grid_mesh, positions):
|
|||
# Reshape positions to a flat list of 3D coordinates
|
||||
positions = positions.reshape([-1, 3])
|
||||
# Expand dimensions to calculate neighbor coordinates
|
||||
positions = jax.tree.map(lambda p : jnp.expand_dims(p, 1) , positions)
|
||||
positions = jax.tree.map(lambda p: jnp.expand_dims(p, 1), positions)
|
||||
# Floor the positions to get the base grid cell for each particle
|
||||
floor = jax.tree.map(jnp.floor , positions)
|
||||
floor = jax.tree.map(jnp.floor, positions)
|
||||
# Define connections to calculate all neighbor coordinates
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
# Calculate the 8 neighboring coordinates
|
||||
neighboor_coords = floor + connection
|
||||
# Calculate kernel weights based on distance from each neighboring coordinate
|
||||
kernel = 1. - jax.tree.map(jnp.abs , positions - neighboor_coords)
|
||||
kernel = 1. - jax.tree.map(jnp.abs, positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
# Modulo operation to wrap around edges if necessary
|
||||
neighboor_coords = jax.tree.map(lambda nc : jnp.mod(nc.astype('int32')
|
||||
,jnp.array(grid_mesh.shape)) , neighboor_coords)
|
||||
neighboor_coords = jax.tree.map(
|
||||
lambda nc: jnp.mod(nc.astype('int32'), jnp.array(grid_mesh.shape)),
|
||||
neighboor_coords)
|
||||
|
||||
# Ensure grid_mesh shape is as expected
|
||||
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
|
||||
grid_mesh = jax.tree.map(lambda g , nc , k : g[nc[...,0], nc[...,1], nc[...,2]] * k , grid_mesh , neighboor_coords , kernel)
|
||||
grid_mesh = jax.tree.map(
|
||||
lambda g, nc, k: g[nc[..., 0], nc[..., 1], nc[..., 2]] * k, grid_mesh,
|
||||
neighboor_coords, kernel)
|
||||
return grid_mesh.sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
|
||||
|
||||
|
||||
|
@ -157,7 +169,9 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
|||
halo_y, _ = halo_size[1]
|
||||
|
||||
original_shape = displacements.shape
|
||||
particle_mesh = jax.tree.map(lambda x : jnp.zeros(x.shape[:-1], dtype=displacements.dtype), displacements)
|
||||
particle_mesh = jax.tree.map(
|
||||
lambda x: jnp.zeros(x.shape[:-1], dtype=displacements.dtype),
|
||||
displacements)
|
||||
if not jnp.isscalar(weight):
|
||||
if weight.shape != original_shape[:-1]:
|
||||
raise ValueError("Weight shape must match particle shape")
|
||||
|
@ -165,13 +179,18 @@ def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
|
|||
weight = weight.flatten()
|
||||
# Padding is forced to be zero in a single gpu run
|
||||
|
||||
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
|
||||
jnp.arange(x.shape[1]),
|
||||
jnp.arange(x.shape[2]),
|
||||
indexing='ij') , axis=0), particle_mesh)
|
||||
|
||||
particle_mesh = jax.tree.map(lambda x : jnp.pad(x, halo_size), particle_mesh)
|
||||
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||
a, b, c = jax.tree.map(
|
||||
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(x.shape[0]),
|
||||
jnp.arange(x.shape[1]),
|
||||
jnp.arange(x.shape[2]),
|
||||
indexing='ij'),
|
||||
axis=0), particle_mesh)
|
||||
|
||||
particle_mesh = jax.tree.map(lambda x: jnp.pad(x, halo_size),
|
||||
particle_mesh)
|
||||
pmid = jax.tree.map(
|
||||
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
|
||||
c)
|
||||
return scatter(pmid.reshape([-1, 3]),
|
||||
displacements.reshape([-1, 3]),
|
||||
particle_mesh,
|
||||
|
@ -217,12 +236,16 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
|
|||
jnp.arange(original_shape[1]),
|
||||
jnp.arange(original_shape[2]),
|
||||
indexing='ij')
|
||||
a, b, c = jax.tree.map( lambda x : jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||
jnp.arange(original_shape[1]),
|
||||
jnp.arange(original_shape[2]),
|
||||
indexing='ij') , axis=0), grid_mesh)
|
||||
a, b, c = jax.tree.map(
|
||||
lambda x: jnp.stack(jnp.meshgrid(jnp.arange(original_shape[0]),
|
||||
jnp.arange(original_shape[1]),
|
||||
jnp.arange(original_shape[2]),
|
||||
indexing='ij'),
|
||||
axis=0), grid_mesh)
|
||||
|
||||
pmid = jax.tree.map(lambda a, b, c : jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b, c)
|
||||
pmid = jax.tree.map(
|
||||
lambda a, b, c: jnp.stack([a + halo_x, b + halo_y, c], axis=-1), a, b,
|
||||
c)
|
||||
pmid = pmid.reshape([-1, 3])
|
||||
disp = disp.reshape([-1, 3])
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
|
|||
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||
new_cell_size, new_shape):
|
||||
"""Multilinear enmeshing."""
|
||||
base_indices = jax.tree.map(jnp.asarray , base_indices)
|
||||
displacements = jax.tree.map(jnp.asarray , displacements)
|
||||
base_indices = jax.tree.map(jnp.asarray, base_indices)
|
||||
displacements = jax.tree.map(jnp.asarray, displacements)
|
||||
with jax.experimental.enable_x64():
|
||||
cell_size = jnp.float64(
|
||||
cell_size) if new_cell_size is not None else jnp.array(
|
||||
|
@ -61,8 +61,8 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
|||
new_displacements = particle_positions - new_indices * new_cell_size
|
||||
|
||||
if base_shape is not None:
|
||||
new_displacements -= jax.tree.map(jnp.rint ,
|
||||
new_displacements / grid_length
|
||||
new_displacements -= jax.tree.map(
|
||||
jnp.rint, new_displacements / grid_length
|
||||
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||
|
||||
new_indices = new_indices.astype(base_indices.dtype)
|
||||
|
@ -89,7 +89,7 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
|||
if base_shape is not None:
|
||||
new_indices %= base_shape
|
||||
|
||||
weights = 1 - jax.tree.map(jnp.abs , new_displacements)
|
||||
weights = 1 - jax.tree.map(jnp.abs, new_displacements)
|
||||
|
||||
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||
|
@ -109,11 +109,15 @@ def _scatter_chunk(carry, chunk):
|
|||
ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
|
||||
spatial_shape)
|
||||
# scatter
|
||||
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
||||
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
|
||||
ind)
|
||||
mesh_structure = jax.tree.structure(mesh)
|
||||
val_flat = jax.tree.leaves(val)
|
||||
val_tree = jax.tree.unflatten(mesh_structure, val_flat)
|
||||
mesh = jax.tree.map(lambda m , v , i, f : m.at[i].add(jnp.multiply(jnp.expand_dims(v, axis=-1), f)) , mesh , val_tree ,ind , frac)
|
||||
mesh = jax.tree.map(
|
||||
lambda m, v, i, f: m.at[i].add(
|
||||
jnp.multiply(jnp.expand_dims(v, axis=-1), f)), mesh, val_tree, ind,
|
||||
frac)
|
||||
carry = mesh, offset, cell_size, mesh_shape
|
||||
return carry, None
|
||||
|
||||
|
@ -125,10 +129,10 @@ def scatter(pmid,
|
|||
val=1.,
|
||||
offset=0,
|
||||
cell_size=1.):
|
||||
|
||||
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
val = jax.tree.map(jnp.asarray , val)
|
||||
mesh = jax.tree.map(jnp.asarray , mesh)
|
||||
val = jax.tree.map(jnp.asarray, val)
|
||||
mesh = jax.tree.map(jnp.asarray, mesh)
|
||||
remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
|
||||
carry = mesh, offset, cell_size, mesh.shape
|
||||
if remainder is not None:
|
||||
|
@ -151,9 +155,9 @@ def _chunk_cat(remainder_array, chunked_array):
|
|||
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||
ptcl_num, spatial_ndim = pmid.shape
|
||||
|
||||
mesh = jax.tree.map(jnp.asarray , mesh)
|
||||
mesh = jax.tree.map(jnp.asarray, mesh)
|
||||
|
||||
val = jax.tree.map(jnp.asarray , val)
|
||||
val = jax.tree.map(jnp.asarray, val)
|
||||
|
||||
if mesh.shape[spatial_ndim:] != val.shape[1:]:
|
||||
raise ValueError('channel shape mismatch: '
|
||||
|
@ -187,11 +191,15 @@ def _gather_chunk(carry, chunk):
|
|||
spatial_shape)
|
||||
|
||||
# gather
|
||||
ind = jax.tree.map(lambda x : tuple(x[..., i] for i in range(spatial_ndim)) , ind)
|
||||
ind = jax.tree.map(lambda x: tuple(x[..., i] for i in range(spatial_ndim)),
|
||||
ind)
|
||||
frac = jax.tree.map(lambda x: jnp.expand_dims(x, chan_axis), frac)
|
||||
ind_structure = jax.tree.structure(ind)
|
||||
frac_structure = jax.tree.structure(frac)
|
||||
mesh_structure = jax.tree.structure(mesh)
|
||||
val += jax.tree.map(lambda m , i , f : (m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1) , mesh , ind , frac)
|
||||
val += jax.tree.map(
|
||||
lambda m, i, f:
|
||||
(m.at[i].get(mode='drop', fill_value=0) * f).sum(axis=1), mesh, ind,
|
||||
frac)
|
||||
|
||||
return carry, val
|
||||
|
|
19
jaxpm/pm.py
19
jaxpm/pm.py
|
@ -1,3 +1,4 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_cosmo as jc
|
||||
|
||||
|
@ -7,7 +8,7 @@ from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
|||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||
invlaplace_kernel, longrange_kernel)
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||
import jax
|
||||
|
||||
|
||||
def pm_forces(positions,
|
||||
mesh_shape=None,
|
||||
|
@ -52,10 +53,12 @@ def pm_forces(positions,
|
|||
kvec, r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
forces = [
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||
) for i in range(3)]
|
||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
forces = jax.tree.map(lambda x ,y ,z : jnp.stack([x,y,z], axis=-1), forces[0], forces[1], forces[2])
|
||||
forces = jax.tree.map(lambda x, y, z: jnp.stack([x, y, z], axis=-1),
|
||||
forces[0], forces[1], forces[2])
|
||||
|
||||
return forces
|
||||
|
||||
|
@ -73,8 +76,9 @@ def lpt(cosmo,
|
|||
"""
|
||||
paint_absolute_pos = particles is not None
|
||||
if particles is None:
|
||||
particles = jax.tree.map(lambda ic : jnp.zeros_like(ic,
|
||||
shape=(*ic.shape, 3)) , initial_conditions)
|
||||
particles = jax.tree.map(
|
||||
lambda ic: jnp.zeros_like(ic, shape=(*ic.shape, 3)),
|
||||
initial_conditions)
|
||||
|
||||
a = jnp.atleast_1d(a)
|
||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||
|
@ -198,7 +202,8 @@ def make_diffrax_ode(mesh_shape,
|
|||
# Computes the update of velocity (kick)
|
||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||
|
||||
return jax.tree.map(lambda dp , dv : jnp.stack([dp, dv],axis=0), dpos, dvel)
|
||||
return jax.tree.map(lambda dp, dv: jnp.stack([dp, dv], axis=0), dpos,
|
||||
dvel)
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.scipy.stats import norm
|
||||
from scipy.special import legendre
|
||||
import jax
|
||||
|
||||
__all__ = [
|
||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||
]
|
||||
|
||||
|
||||
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||
"""
|
||||
Parameters
|
||||
|
@ -100,11 +101,12 @@ def power_spectrum(mesh,
|
|||
n_bins = len(kavg) + 2
|
||||
|
||||
# FFTs
|
||||
meshk = jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho') , mesh)
|
||||
meshk = jax.tree.map(lambda x: jnp.fft.fftn(x, norm='ortho'), mesh)
|
||||
if mesh2 is None:
|
||||
mmk = meshk.real**2 + meshk.imag**2
|
||||
else:
|
||||
mmk = meshk * jax.tree.map(lambda x : jnp.fft.fftn(x, norm='ortho').conj() , mesh2)
|
||||
mmk = meshk * jax.tree.map(
|
||||
lambda x: jnp.fft.fftn(x, norm='ortho').conj(), mesh2)
|
||||
|
||||
# Sum powers
|
||||
pk = jnp.empty((len(poles), n_bins))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue