jaxdecomp proto (#21)

* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operations in Fourier space now take the inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
This commit is contained in:
Wassim KABALAN 2024-12-20 11:44:02 +01:00 committed by GitHub
parent a0a79277e5
commit df8602b318
26 changed files with 1871 additions and 434 deletions

198
jaxpm/distributed.py Normal file
View file

@ -0,0 +1,198 @@
from typing import Any, Callable, Hashable
Specs = Any
AxisName = Hashable
from functools import partial
import jax
import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P
def autoshmap(
f: Callable,
gpu_mesh: Mesh | AbstractMesh | None,
in_specs: Specs,
out_specs: Specs,
check_rep: bool = False,
auto: frozenset[AxisName] = frozenset()) -> Callable:
"""Helper function to wrap the provided function in a shard map if
the code is being executed in a mesh context."""
if gpu_mesh is None or gpu_mesh.empty:
return f
else:
return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)
def fft3d(x):
    """Forward 3D FFT of ``x``, delegated to ``jaxdecomp.pfft3d``."""
    spectrum = jaxdecomp.pfft3d(x)
    return spectrum
def ifft3d(x):
    """Inverse 3D FFT of ``x`` via ``jaxdecomp.pifft3d``, keeping the real part."""
    field = jaxdecomp.pifft3d(x)
    return field.real
def get_halo_size(halo_size, sharding):
    """Derive padding widths and halo-exchange extents for a sharding.

    Args:
        halo_size: Requested halo width (integer) per side.
        sharding: Object exposing ``.mesh`` (with ``.empty`` and
            ``.devices.shape``), or ``None``.

    Returns:
        A pair ``(pad_width, extents)``. ``pad_width`` is
        ``((x_lo, x_hi), (y_lo, y_hi), (0, 0))`` and ``extents`` is
        ``(x_ext, y_ext)`` with ``ext = halo_size // 2``. An axis whose mesh
        dimension is 1 gets zero padding and zero extent; with no (or an
        empty) mesh everything is zero.
    """
    no_halo = (0, 0)
    gpu_mesh = None if sharding is None else sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        return (no_halo, no_halo, no_halo), no_halo

    pdims = gpu_mesh.devices.shape
    split_x = pdims[0] != 1
    split_y = pdims[1] != 1

    halo_x = (halo_size, halo_size) if split_x else (0, 0)
    halo_y = (halo_size, halo_size) if split_y else (0, 0)
    halo_x_ext = halo_size // 2 if split_x else 0
    halo_y_ext = halo_size // 2 if split_y else 0
    return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))
def halo_exchange(x, halo_extents, halo_periods=(True, True)):
    """Run ``jaxdecomp.halo_exchange`` only when a halo extent is positive.

    Args:
        x: Array to exchange halos on.
        halo_extents: Pair of extents; the first two entries are inspected.
        halo_periods: Forwarded unchanged to ``jaxdecomp.halo_exchange``.

    Returns:
        ``x`` unchanged when neither extent is positive, otherwise the result
        of ``jaxdecomp.halo_exchange``.
    """
    needs_exchange = halo_extents[0] > 0 or halo_extents[1] > 0
    if not needs_exchange:
        return x
    return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
def slice_unpad_impl(x, pad_width):
    """Fold halo strips into the interior and strip the padding.

    For each of the first two axes with halo width ``h`` and ``half = h // 2``:
    the leading ``half`` entries are added onto the first ``half`` interior
    entries (``[h:h + half]``), and the trailing ``half`` entries are added
    onto the last ``half`` interior entries (``[-(h + half):-h]``). The ``h``
    padding entries on both sides are then removed.

    Args:
        x: Padded local array; indexed with three slices, so it is expected
            to have at least three dimensions.
        pad_width: ``((x_lo, x_hi), (y_lo, y_hi), ...)``; only the low widths
            of the first two axes are read.

    Returns:
        The array with halos folded in and padding removed along the axes
        whose halo width is positive.
    """
    halo_x, _ = pad_width[0]
    halo_y, _ = pad_width[1]
    # Compute the strip width once. Writing ``-halo // 2`` would parse as
    # ``(-halo) // 2`` and round the wrong way for odd halos.
    half_x = halo_x // 2
    half_y = halo_y // 2

    # Skip an axis entirely when there is nothing to fold: with half == 0,
    # ``x[-0:]`` would select the whole axis instead of an empty strip.
    if half_x > 0:
        x = x.at[halo_x:halo_x + half_x].add(x[:half_x])
        x = x.at[-(halo_x + half_x):-halo_x].add(x[-half_x:])
    if half_y > 0:
        x = x.at[:, halo_y:halo_y + half_y].add(x[:, :half_y])
        x = x.at[:, -(halo_y + half_y):-halo_y].add(x[:, -half_y:])

    unpad_slice = [slice(None)] * 3
    if halo_x > 0:
        unpad_slice[0] = slice(halo_x, -halo_x)
    if halo_y > 0:
        unpad_slice[1] = slice(halo_y, -halo_y)
    return x[tuple(unpad_slice)]
def slice_pad(x, pad_width, sharding):
    """Pad every shard of ``x`` with ``jnp.pad`` when padding is requested.

    Args:
        x: Array to pad.
        pad_width: Per-axis ``(low, high)`` widths passed to ``jnp.pad``; the
            low widths of the first two axes decide whether anything is done.
        sharding: Object exposing ``.mesh`` and ``.spec``, or ``None``.

    Returns:
        ``x`` unchanged when there is no usable mesh or both leading low
        widths are zero; otherwise the result of applying ``jnp.pad`` to each
        shard under ``shard_map``.
    """
    gpu_mesh = None if sharding is None else sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        return x
    if not (pad_width[0][0] > 0 or pad_width[1][0] > 0):
        return x

    assert sharding is not None
    spec = sharding.spec
    pad_shard = partial(jnp.pad, pad_width=pad_width)
    padded = shard_map(pad_shard,
                       mesh=gpu_mesh,
                       in_specs=spec,
                       out_specs=spec)(x)
    return padded
def slice_unpad(x, pad_width, sharding):
    """Apply ``slice_unpad_impl`` to every shard of ``x`` when padding was used.

    Args:
        x: Padded array.
        pad_width: Per-axis ``(low, high)`` widths; the low widths of the
            first two axes decide whether anything is done.
        sharding: Object exposing ``.mesh`` and ``.spec``, or ``None``.

    Returns:
        ``x`` unchanged when there is no usable mesh or both leading low
        widths are zero; otherwise the per-shard result of
        ``slice_unpad_impl`` under ``shard_map``.
    """
    mesh = None if sharding is None else sharding.mesh
    if mesh is None or mesh.empty:
        return x
    if not (pad_width[0][0] > 0 or pad_width[1][0] > 0):
        return x

    assert sharding is not None
    spec = sharding.spec
    unpad_shard = partial(slice_unpad_impl, pad_width=pad_width)
    unpadded = shard_map(unpad_shard,
                         mesh=mesh,
                         in_specs=spec,
                         out_specs=spec)(x)
    return unpadded
def get_local_shape(mesh_shape, sharding=None):
    """Return the per-device shape of a mesh with global shape ``mesh_shape``.

    Args:
        mesh_shape: Global shape; the first two entries are divided.
        sharding: Object exposing ``.mesh`` (with ``.empty`` and
            ``.devices.shape``), or ``None``.

    Returns:
        ``mesh_shape`` itself when there is no usable mesh. Otherwise a list
        whose first two entries are floor-divided by the first two device-mesh
        dimensions and whose remaining entries are copied unchanged.
    """
    gpu_mesh = None if sharding is None else sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        return mesh_shape

    pdims = gpu_mesh.devices.shape
    local_x = mesh_shape[0] // pdims[0]
    local_y = mesh_shape[1] // pdims[1]
    return [local_x, local_y, *mesh_shape[2:]]
def _axis_names(spec):
if len(spec) == 1:
x_axis, = spec
y_axis = None
single_axis = True
elif len(spec) == 2:
x_axis, y_axis = spec
if y_axis == None:
single_axis = True
elif x_axis == None:
x_axis = y_axis
single_axis = True
else:
single_axis = False
else:
raise ValueError("Only 1 or 2 axis sharding is supported")
return x_axis, y_axis, single_axis
def uniform_particles(mesh_shape, sharding=None):
    """Build a grid of integer particle coordinates, one particle per cell.

    Args:
        mesh_shape: Global mesh shape.
        sharding: Object exposing ``.mesh`` and ``.spec``, or ``None``.

    Returns:
        An array whose last axis (size ``len(mesh_shape)`` on the
        single-device path, 3 on the sharded path) holds the ``'ij'``-indexed
        grid coordinates. On the sharded path each shard offsets its local
        x/y coordinates by its axis index times the local extent.
    """
    gpu_mesh = None if sharding is None else sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        axes = [jnp.arange(s) for s in mesh_shape]
        return jnp.stack(jnp.meshgrid(*axes, indexing='ij'), axis=-1)

    local_mesh_shape = get_local_shape(mesh_shape, sharding)
    spec = sharding.spec
    x_axis, y_axis, single_axis = _axis_names(spec)

    def particles():
        # Position of this shard along each sharded mesh axis.
        x_indx = lax.axis_index(x_axis)
        y_indx = 0 if single_axis else lax.axis_index(y_axis)

        nx, ny, nz = (local_mesh_shape[0], local_mesh_shape[1],
                      local_mesh_shape[2])
        x = jnp.arange(nx) + x_indx * nx
        y = jnp.arange(ny) + y_indx * ny
        z = jnp.arange(nz)
        return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)

    return shard_map(particles, mesh=gpu_mesh, in_specs=(),
                     out_specs=spec)()
def normal_field(mesh_shape, seed, sharding=None):
    """Draw a field of independent standard-normal samples.

    Args:
        mesh_shape: Global shape of the field.
        seed: JAX PRNG key.
        sharding: Object exposing ``.mesh`` and ``.spec``, or ``None``.

    Returns:
        Without a usable mesh, ``jax.random.normal`` of shape ``mesh_shape``.
        Otherwise a sharded array in which every shard draws its local block
        (as ``float32``) from its own key.
    """
    gpu_mesh = None if sharding is None else sharding.mesh
    if gpu_mesh is None or gpu_mesh.empty:
        return jax.random.normal(shape=mesh_shape, key=seed)

    local_mesh_shape = get_local_shape(mesh_shape, sharding)
    # One key per device, selected inside the shard by axis index, so the
    # same code path works for multi-host and single-controller runs
    # (``jax.process_index`` is multi-host only).
    keys = jax.random.split(seed, jax.device_count())
    spec = sharding.spec
    x_axis, y_axis, single_axis = _axis_names(spec)

    def normal(keys, shape, dtype):
        idx = lax.axis_index(x_axis)
        if not single_axis:
            # Flatten the 2D shard coordinates into one key index.
            y_index = lax.axis_index(y_axis)
            x_size = lax.psum(1, axis_name=x_axis)
            idx += y_index * x_size
        return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)

    sample_shard = partial(normal, shape=local_mesh_shape, dtype='float32')
    return shard_map(sample_shard,
                     mesh=gpu_mesh,
                     in_specs=P(None),
                     out_specs=spec)(keys)

View file

@ -1,6 +1,6 @@
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.scipy.ode import odeint
@ -587,5 +587,6 @@ def dGf2a(cosmo, a):
cache = cosmo._workspace['background.growth_factor']
f2p = cache['h2'] / cache['a'] * cache['g2']
f2p = interp(np.log(a), np.log(cache['a']), f2p)
E = E(cosmo, a)
return (f2p * a**3 * E + D2f * a**3 * dEa(cosmo, a) + 3 * a**2 * E * D2f)
E_a = E(cosmo, a)
return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
3 * a**2 * E_a * D2f)

View file

@ -1,30 +1,46 @@
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
from jaxpm.distributed import autoshmap
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
"""
Return wave-vectors for a given shape
def fftk(k_array):
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
Generate Fourier transform wave numbers for a given mesh.
k.append(kd.astype(dtype))
del kd, kdshape
return k
Args:
nc (int): Shape of the mesh grid.
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = fftfreq3d(k_array)
# to the order of dimensions in the transposed FFT
return kx, ky, kz
def interpolate_power_spectrum(input, k, pk, sharding=None):
    """Linearly interpolate ``pk`` (sampled at ``k``) at every entry of ``input``.

    The interpolation flattens its argument, calls ``jnp.interp`` and restores
    the shape. It is wrapped with ``autoshmap`` using the specs returned by
    ``get_output_specs(FftType.FFT, ...)`` when a mesh is available.

    Args:
        input: Array of values at which to evaluate the interpolant.
        k: Sample points passed to ``jnp.interp``.
        pk: Sample values passed to ``jnp.interp``.
        sharding: Object exposing ``.mesh`` and ``.spec``, or ``None``.

    Returns:
        Array with the same shape as ``input``.
    """

    def pk_fn(x):
        return jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

    if sharding is not None:
        gpu_mesh, specs = sharding.mesh, sharding.spec
    else:
        gpu_mesh, specs = None, P()

    if gpu_mesh is not None:
        out_specs = P(*get_output_specs(FftType.FFT, specs, mesh=gpu_mesh))
    else:
        out_specs = P()

    interpolate = autoshmap(pk_fn,
                            gpu_mesh=gpu_mesh,
                            in_specs=out_specs,
                            out_specs=out_specs)
    return interpolate(input)
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
Parameters
-----------
kvec: list
@ -50,23 +66,30 @@ def gradient_kernel(kvec, direction, order=1):
return wts
def invlaplace_kernel(kvec):
def invlaplace_kernel(kvec, fd=False):
"""
Compute the inverse Laplace kernel
Compute the inverse Laplace kernel.
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
Parameters
-----------
kvec: list
List of wave-vectors
fd: bool
Finite difference kernel
Returns
--------
wts: array
Complex kernel values
"""
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk==0, 1, kk)
return - jnp.where(kk==0, 0, 1 / kk_nozeros)
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
def longrange_kernel(kvec, r_split):
@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
List of wave-vectors
r_split: float
Splitting radius
Returns
--------
wts: array
Complex kernel values
TODO: @modichirag add documentation
"""
if r_split != 0:
@ -105,13 +126,12 @@ def cic_compensation(kvec):
-----------
kvec: list
List of wave-vectors
Returns:
--------
wts: array
Complex kernel values
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts

View file

@ -1,15 +1,24 @@
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
ifft3d, slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter
def cic_paint(mesh, positions, weight=None):
def _cic_paint_impl(grid_mesh, positions, weight=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
mesh: [nx, ny, nz]
displacement field: [nx, ny, nz, 3]
"""
positions = positions.reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
@ -19,48 +28,106 @@ def cic_paint(mesh, positions, weight=None):
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
if weight is not None:
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
if jnp.isscalar(weight):
kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
else:
kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
kernel)
neighboor_coords = jnp.mod(
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
jnp.array(mesh.shape))
jnp.array(grid_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(mesh, neighboor_coords, kernel.reshape([-1, 8]),
dnums)
mesh = lax.scatter_add(grid_mesh, neighboor_coords,
kernel.reshape([-1, 8]), dnums)
return mesh
def cic_read(mesh, positions):
@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):
    """Paint particles onto ``grid_mesh``, with optional halo handling.

    Steps: pad each shard by the halo, paint locally with
    ``_cic_paint_impl``, exchange halos, then remove the padding.

    Args:
        grid_mesh: Mesh to paint onto.
        positions: Particle positions; reshaped to ``(*grid_mesh.shape, 3)``,
            so one particle per mesh cell is expected.
        weight: Optional weight forwarded to ``_cic_paint_impl``.
        halo_size: Halo width (static under ``jit``).
        sharding: Sharding description (static under ``jit``); only a
            ``NamedSharding`` enables the sharded paint.

    Returns:
        The painted mesh with padding removed.
    """
    positions = positions.reshape((*grid_mesh.shape, 3))
    halo_size, halo_extents = get_halo_size(halo_size, sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding)

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    paint = autoshmap(_cic_paint_impl,
                      gpu_mesh=gpu_mesh,
                      in_specs=(spec, spec, P()),
                      out_specs=spec)
    grid_mesh = paint(grid_mesh, positions, weight)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    return slice_unpad(grid_mesh, halo_size, sharding)
def _cic_read_impl(grid_mesh, positions):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
mesh: [nx, ny, nz]
positions: [nx,ny,nz, 3]
"""
# Save original shape for reshaping output later
original_shape = positions.shape
# Reshape positions to a flat list of 3D coordinates
positions = positions.reshape([-1, 3])
# Expand dimensions to calculate neighbor coordinates
positions = jnp.expand_dims(positions, 1)
# Floor the positions to get the base grid cell for each particle
floor = jnp.floor(positions)
# Define connections to calculate all neighbor coordinates
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
# Calculate the 8 neighboring coordinates
neighboor_coords = floor + connection
# Calculate kernel weights based on distance from each neighboring coordinate
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
# Modulo operation to wrap around edges if necessary
neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
jnp.array(mesh.shape))
jnp.array(grid_mesh.shape))
# Ensure grid_mesh shape is as expected
# Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
return (grid_mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1]) # yapf: disable
return (mesh[neighboor_coords[..., 0], neighboor_coords[..., 1],
neighboor_coords[..., 3]] * kernel).sum(axis=-1)
@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):
    """Read ``grid_mesh`` at particle positions, with optional halo handling.

    Steps: pad each shard by the halo, exchange halos, then read locally with
    ``_cic_read_impl``.

    Args:
        grid_mesh: Mesh to read from.
        positions: Particle positions; reshaped to ``(*grid_mesh.shape, 3)``,
            so one particle per mesh cell is expected.
        halo_size: Halo width (static under ``jit``).
        sharding: Sharding description (static under ``jit``); only a
            ``NamedSharding`` enables the sharded read.

    Returns:
        Read values reshaped to ``positions.shape[:-1]`` (the shape before the
        internal reshape).
    """
    out_shape = positions.shape[:-1]
    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    read = autoshmap(_cic_read_impl,
                     gpu_mesh=gpu_mesh,
                     in_specs=(spec, spec),
                     out_specs=spec)
    displacement = read(grid_mesh, positions)
    return displacement.reshape(out_shape)
def cic_paint_2d(mesh, positions, weight):
""" Paints positions onto a 2d mesh
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
mesh: [nx, ny]
positions: [npart, 2]
weight: [npart]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])
@ -84,17 +151,109 @@ def cic_paint_2d(mesh, positions, weight):
return mesh
def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):
    """Paint particles given as displacements from their own grid cell.

    Each particle starts at the cell whose index matches its position in
    ``displacements``; the base indices are shifted by the x/y halo widths so
    they address the padded mesh.

    Args:
        displacements: Array whose last axis holds the 3 displacement
            components; the leading axes define the (local) mesh shape.
        halo_size: Per-axis ``(low, high)`` padding passed to ``jnp.pad``.
        weight: Scalar, or array matching ``displacements.shape[:-1]``.
        chunk_size: Chunk size forwarded to ``scatter``.

    Returns:
        The padded mesh with particles scattered onto it.

    Raises:
        ValueError: If a non-scalar ``weight`` does not match the particle
            shape.
    """
    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = displacements.shape
    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
    if not jnp.isscalar(weight):
        if weight.shape != original_shape[:-1]:
            raise ValueError("Weight shape must match particle shape")
        weight = weight.flatten()

    # Padding is forced to be zero in a single gpu run
    a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                           jnp.arange(particle_mesh.shape[1]),
                           jnp.arange(particle_mesh.shape[2]),
                           indexing='ij')

    particle_mesh = jnp.pad(particle_mesh, halo_size)
    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
    # Forward the caller's chunk size instead of a hard-coded 2**24, which
    # silently ignored the parameter.
    return scatter(pmid.reshape([-1, 3]),
                   displacements.reshape([-1, 3]),
                   particle_mesh,
                   chunk_size=chunk_size,
                   val=weight)
@partial(jax.jit, static_argnums=(1, 2, 4))
def cic_paint_dx(displacements,
                 halo_size=0,
                 sharding=None,
                 weight=1.0,
                 chunk_size=2**24):
    """Paint particles described by displacements from a uniform grid.

    Steps: paint locally onto a halo-padded mesh with ``_cic_paint_dx_impl``,
    exchange halos, then remove the padding.

    Args:
        displacements: Particle displacements, last axis of size 3.
        halo_size: Halo width (static under ``jit``).
        sharding: Sharding description (static under ``jit``); only a
            ``NamedSharding`` enables the sharded paint.
        weight: Weight forwarded to ``_cic_paint_dx_impl``.
        chunk_size: Chunk size forwarded to ``_cic_paint_dx_impl`` (static).

    Returns:
        The painted mesh with padding removed.
    """
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    paint_local = partial(_cic_paint_dx_impl,
                          halo_size=halo_size,
                          weight=weight,
                          chunk_size=chunk_size)
    paint = autoshmap(paint_local,
                      gpu_mesh=gpu_mesh,
                      in_specs=spec,
                      out_specs=spec)
    grid_mesh = paint(displacements)

    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    return slice_unpad(grid_mesh, halo_size, sharding)
def _cic_read_dx_impl(grid_mesh, disp, halo_size):
    """Read a padded mesh at particles given as displacements from grid cells.

    Args:
        grid_mesh: Halo-padded local mesh.
        disp: Displacements, reshaped internally to ``(-1, 3)``.
        halo_size: Per-axis ``(low, high)`` padding of ``grid_mesh``.

    Returns:
        Gathered values reshaped to the unpadded mesh shape
        (``dim - 2 * low`` per axis).
    """
    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    # Shape of the mesh before padding was added.
    original_shape = [
        dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
    ]
    nx, ny, nz = original_shape[0], original_shape[1], original_shape[2]
    a, b, c = jnp.meshgrid(jnp.arange(nx),
                           jnp.arange(ny),
                           jnp.arange(nz),
                           indexing='ij')

    # Base cell of every particle, shifted into the padded mesh.
    base_cells = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
    flat_cells = base_cells.reshape([-1, 3])
    flat_disp = disp.reshape([-1, 3])
    values = gather(flat_cells, flat_disp, grid_mesh)
    return values.reshape(original_shape)
@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
    """Read ``grid_mesh`` at particles described by displacements from a grid.

    Steps: pad each shard by the halo, exchange halos, then read locally with
    ``_cic_read_dx_impl``.

    Args:
        grid_mesh: Mesh to read from.
        disp: Particle displacements, last axis of size 3.
        halo_size: Halo width (static under ``jit``).
        sharding: Sharding description (static under ``jit``); only a
            ``NamedSharding`` enables the sharded read.

    Returns:
        The values read at the particle positions.
    """
    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))

    if isinstance(sharding, NamedSharding):
        gpu_mesh, spec = sharding.mesh, sharding.spec
    else:
        gpu_mesh, spec = None, P()

    read_local = partial(_cic_read_dx_impl, halo_size=halo_size)
    # A single spec is used as a prefix for both positional inputs.
    read = autoshmap(read_local,
                     gpu_mesh=gpu_mesh,
                     in_specs=spec,
                     out_specs=spec)
    return read(grid_mesh, disp)
def compensate_cic(field):
"""
Compensate for CiC painting
Args:
field: input 3D cic-painted field
Returns:
compensated_field
"""
nc = field.shape
kvec = fftk(nc)
Compensate for CiC painting
Args:
field: input 3D cic-painted field
Returns:
compensated_field
"""
delta_k = fft3d(field)
delta_k = jnp.fft.rfftn(field)
kvec = fftk(delta_k)
delta_k = cic_compensation(kvec) * delta_k
return jnp.fft.irfftn(delta_k)
return ifft3d(delta_k)

190
jaxpm/painting_utils.py Normal file
View file

@ -0,0 +1,190 @@
import jax
import jax.numpy as jnp
from jax.lax import scan
def _chunk_split(ptcl_num, chunk_size, *arrays):
"""Split and reshape particle arrays into chunks and remainders, with the remainders
preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
remainder_size = ptcl_num % chunk_size
chunk_num = ptcl_num // chunk_size
remainder = None
chunks = arrays
if remainder_size:
remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]
# `scan` triggers errors in scatter and gather without the `full`
chunks = [
x.reshape(chunk_num, chunk_size, *x.shape[1:])
if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
]
return remainder, chunks
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
           new_cell_size, new_shape):
    """Multilinear enmeshing.

    Computes, for every particle, the indices of its ``2**spatial_dim``
    neighboring cells and the multilinear weight of each neighbor.

    Args:
        base_indices: Integer cell indices of the particles; the second axis
            is the spatial dimension.
        displacements: Particle displacements relative to ``base_indices``.
        cell_size: Cell size of the base grid.
        base_shape: Shape of the base grid, or ``None``. When given, indices
            are wrapped periodically.
        offset: Offset subtracted from the particle positions.
        new_cell_size: Cell size of the target grid, or ``None`` to stay on
            the base grid.
        new_shape: Shape of the target grid, or ``None``.

    Returns:
        ``(new_indices, weights)``: neighbor cell indices with a neighbor
        axis inserted after the particle axis, and the matching weights
        (product over the spatial axis of ``1 - |displacement|``).
    """
    base_indices = jnp.asarray(base_indices)
    displacements = jnp.asarray(displacements)
    # Convert the grid parameters with 64-bit types enabled.
    # NOTE(review): the extent of this `with` block was reconstructed from
    # indentation-stripped source — confirm against the repository.
    with jax.experimental.enable_x64():
        cell_size = jnp.float64(
            cell_size) if new_cell_size is not None else jnp.array(
                cell_size, dtype=displacements.dtype)
        if base_shape is not None:
            base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
        offset = jnp.float64(offset)
        if new_cell_size is not None:
            new_cell_size = jnp.float64(new_cell_size)
        if new_shape is not None:
            new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
    spatial_dim = base_indices.shape[1]
    # Row n holds the binary digits of n: every 0/1 corner offset of a cell.
    neighbor_offsets = (
        jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
        jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
    if new_cell_size is not None:
        # Work in physical coordinates, then re-bin onto the target grid.
        particle_positions = base_indices * cell_size + displacements - offset
        particle_positions = particle_positions[:, jnp.
                                                newaxis]  # insert neighbor axis
        new_indices = particle_positions + neighbor_offsets * new_cell_size  # multilinear
        if base_shape is not None:
            # Periodic wrap over the physical length of the base grid.
            grid_length = base_shape * cell_size
            new_indices %= grid_length
        new_indices //= new_cell_size
        new_displacements = particle_positions - new_indices * new_cell_size
        if base_shape is not None:
            new_displacements -= jnp.rint(
                new_displacements / grid_length
            ) * grid_length  # also abs(new_displacements) < new_cell_size is expected
        # Cast back to the callers' dtypes before normalizing.
        new_indices = new_indices.astype(base_indices.dtype)
        new_displacements = new_displacements.astype(displacements.dtype)
        new_cell_size = new_cell_size.astype(displacements.dtype)
        new_displacements /= new_cell_size
    else:
        # Same grid: fold the offset into the indices and displacements.
        offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
        base_indices -= offset_indices.astype(base_indices.dtype)
        displacements -= offset_displacements.astype(displacements.dtype)
        # insert neighbor axis
        base_indices = base_indices[:, jnp.newaxis]
        displacements = displacements[:, jnp.newaxis]
        # multilinear
        displacements /= cell_size
        new_indices = jnp.floor(displacements).astype(base_indices.dtype)
        new_indices += neighbor_offsets
        new_displacements = displacements - new_indices
        new_indices += base_indices
        if base_shape is not None:
            new_indices %= base_shape
    weights = 1 - jnp.abs(new_displacements)
    # Negative indices are replaced by `new_shape`, i.e. one past the last
    # valid index (presumably so they are treated as out of bounds
    # downstream — confirm against the scatter/gather callers).
    if base_shape is None and new_shape is not None:  # all new_indices >= 0 if base_shape is not None
        new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
    weights = weights.prod(axis=-1)
    return new_indices, weights
def _scatter_chunk(carry, chunk):
    """Scatter one chunk of particles onto the mesh (``scan``-compatible).

    Args:
        carry: ``(mesh, offset, cell_size, mesh_shape)``.
        chunk: ``(pmid, disp, val)`` for the particles of this chunk.

    Returns:
        ``(carry, None)`` where ``carry`` holds the updated mesh.
    """
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk

    ndim = pmid.shape[1]

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       mesh.shape)

    # scatter
    index = tuple(ind[..., axis] for axis in range(ndim))
    updates = jnp.multiply(jnp.expand_dims(val, axis=-1), frac)
    mesh = mesh.at[index].add(updates)

    return (mesh, offset, cell_size, mesh_shape), None
def scatter(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1.,
            offset=0,
            cell_size=1.):
    """Scatter particle values onto ``mesh`` in chunks.

    Args:
        pmid: 2-D array of base cell indices, one row per particle.
        disp: Particle displacements relative to ``pmid``.
        mesh: Mesh to add onto.
        chunk_size: Maximum number of particles processed per step.
        val: Scalar or per-particle values to scatter.
        offset: Offset forwarded to ``enmesh``.
        cell_size: Cell size forwarded to ``enmesh``.

    Returns:
        The mesh with all particle contributions added.
    """
    ptcl_num, _ = pmid.shape
    val = jnp.asarray(val)
    mesh = jnp.asarray(mesh)

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)

    carry = (mesh, offset, cell_size, mesh.shape)
    if remainder is not None:
        # The uneven leading part is handled outside the scan.
        carry, _ = _scatter_chunk(carry, remainder)
    carry, _ = scan(_scatter_chunk, carry, chunks)

    return carry[0]
def _chunk_cat(remainder_array, chunked_array):
"""Reshape and concatenate one remainder and one chunked particle arrays."""
array = chunked_array.reshape(-1, *chunked_array.shape[2:])
if remainder_array is not None:
array = jnp.concatenate((remainder_array, array), axis=0)
return array
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
    """Gather mesh values at particle positions in chunks.

    Args:
        pmid: 2-D array of base cell indices, one row per particle.
        disp: Particle displacements relative to ``pmid``.
        mesh: Mesh to read from; axes beyond the spatial ones are channels.
        chunk_size: Maximum number of particles processed per step.
        val: Initial values the gathered contributions are added to.
        offset: Offset forwarded to ``enmesh``.
        cell_size: Cell size forwarded to ``enmesh``.

    Returns:
        Per-particle gathered values, remainder particles first.

    Raises:
        ValueError: If the channel shape of ``mesh`` differs from
            ``val.shape[1:]``.
    """
    ptcl_num, spatial_ndim = pmid.shape

    mesh = jnp.asarray(mesh)
    val = jnp.asarray(val)

    if mesh.shape[spatial_ndim:] != val.shape[1:]:
        raise ValueError('channel shape mismatch: '
                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
    carry = (mesh, offset, cell_size, mesh.shape)

    # The uneven leading part is handled outside the scan.
    head = None
    if remainder is not None:
        _, head = _gather_chunk(carry, remainder)
    _, body = scan(_gather_chunk, carry, chunks)

    return _chunk_cat(head, body)
def _gather_chunk(carry, chunk):
    """Gather mesh values for one chunk of particles (``scan``-compatible).

    Args:
        carry: ``(mesh, offset, cell_size, mesh_shape)``; returned unchanged.
        chunk: ``(pmid, disp, val)`` for the particles of this chunk.

    Returns:
        ``(carry, val)`` where ``val`` has the weighted neighbor values added.
    """
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk

    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape[:spatial_ndim]
    # Trailing mesh axes beyond the spatial ones are channels.
    chan_ndim = mesh.ndim - spatial_ndim
    chan_axis = tuple(range(-chan_ndim, 0))

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)

    # gather
    index = tuple(ind[..., axis] for axis in range(spatial_ndim))
    frac = jnp.expand_dims(frac, chan_axis)
    neighbors = mesh.at[index].get(mode='drop', fill_value=0)
    val += (neighbors * frac).sum(axis=1)

    return carry, val

129
jaxpm/plotting.py Normal file
View file

@ -0,0 +1,129 @@
import matplotlib.pyplot as plt
import numpy as np
def plot_fields(fields_dict, sum_over=None):
    """Plot sum projections of 3D fields along each of the three axes.

    One row is drawn per field and one column per projection axis. Only the
    first ``sum_over`` slices along the projection axis are summed.

    Args:
        fields_dict: Mapping from field name (used in titles) to 3D array.
        sum_over: Number of slices to sum along each axis. Defaults to
            ``first_field.shape[0] // 8``.
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
    nb_rows = len(fields_dict)
    nb_cols = 3
    # squeeze=False keeps `axes` two-dimensional even with a single field,
    # so `axes[row, col]` indexing always works.
    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(15, 5 * nb_rows),
                             squeeze=False)

    def plot_subplots(proj_axis, field, row, title):
        slicing = [slice(None)] * field.ndim
        slicing[proj_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # The image shows the two axes that were not summed over: rows are
        # the first remaining axis (vertical), columns the second (horizontal).
        row_axis, col_axis = [ax for ax in range(3) if ax != proj_axis]

        # Sum projection over the specified axis and plot
        axes[row, proj_axis].imshow(
            field[slicing].sum(axis=proj_axis) + 1,
            cmap='magma',
            extent=[0, field.shape[col_axis], 0, field.shape[row_axis]])
        axes[row, proj_axis].set_xlabel('Mpc/h')
        axes[row, proj_axis].set_ylabel('Mpc/h')
        axes[row, proj_axis].set_title(title)

    # Plot each field across the three axes
    for i, (name, field) in enumerate(fields_dict.items()):
        for proj_axis in range(3):
            plot_subplots(proj_axis, field, i,
                          f"{name} projection {proj_axis}")

    plt.tight_layout()
    plt.show()
def plot_fields_single_projection(fields_dict,
                                  sum_over=None,
                                  project_axis=0,
                                  vmin=None,
                                  vmax=None,
                                  colorbar=False):
    """Plot one sum projection per field in a grid with 4 images per row.

    Only the first ``sum_over`` slices along ``project_axis`` are summed.

    Args:
        fields_dict: Mapping from field name (used in titles) to 3D array.
        sum_over: Number of slices to sum along the projection axis.
            Defaults to ``first_field.shape[0] // 8``.
        project_axis: Axis to project along.
        vmin: Lower color limit passed to ``imshow``.
        vmax: Upper color limit passed to ``imshow``.
        colorbar: Whether to attach a colorbar to every image.
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
    nb_fields = len(fields_dict)
    nb_cols = 4  # Set number of images per row
    nb_rows = (nb_fields + nb_cols - 1) // nb_cols  # Calculate required rows

    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(5 * nb_cols, 5 * nb_rows))
    axes = np.atleast_2d(axes)  # Ensure axes is always a 2D array

    for i, (name, field) in enumerate(fields_dict.items()):
        row, col = divmod(i, nb_cols)

        # Select the first `sum_over` slices along the projection axis
        slicing = [slice(None)] * field.ndim
        slicing[project_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # The image shows the two axes that were not summed over: rows are
        # the first remaining axis (vertical), columns the second (horizontal).
        axis = project_axis % field.ndim
        row_axis, col_axis = [ax for ax in range(field.ndim) if ax != axis]

        # Sum projection over the chosen axis and plot
        a = axes[row, col].imshow(
            field[slicing].sum(axis=project_axis) + 1,
            cmap='magma',
            extent=[0, field.shape[col_axis], 0, field.shape[row_axis]],
            vmin=vmin,
            vmax=vmax)
        axes[row, col].set_xlabel('Mpc/h')
        axes[row, col].set_ylabel('Mpc/h')
        # Report the axis actually used instead of a hard-coded 0.
        axes[row, col].set_title(f"{name} projection {project_axis}")
        if colorbar:
            fig.colorbar(a, ax=axes[row, col], shrink=0.7)

    # Remove any empty subplots (count from nb_fields rather than relying on
    # the loop variable leaking out of the loop above).
    for j in range(nb_fields, nb_rows * nb_cols):
        fig.delaxes(axes.flatten()[j])

    plt.tight_layout()
    plt.show()
def stack_slices(array):
    """Assemble the addressable shards of a 2D-sharded array into one array.

    The device grid shape is read from ``array.sharding.mesh.devices.shape``
    and shards are fetched with ``array.addressable_data(index)``, with
    ``index`` running in row-major order over that grid.

    Args:
        array: Sharded array exposing ``.sharding.mesh.devices.shape`` and
            ``.addressable_data``.

    Returns:
        A single NumPy array built by concatenating the shards of each grid
        row along axis 1 and the resulting rows along axis 0.
    """
    pdims = array.sharding.mesh.devices.shape

    field_slices = []
    for i in range(pdims[0]):
        # Row-major flat index: each grid row holds pdims[1] shards, so the
        # stride is pdims[1], not pdims[0].
        row_slices = [
            array.addressable_data(i * pdims[1] + j) for j in range(pdims[1])
        ]
        # Join the shards of this grid row side by side (along axis 1)
        field_slices.append(np.hstack(row_slices))

    # Stack the grid rows on top of each other (along axis 0)
    return np.vstack(field_slices)

View file

@ -1,50 +1,92 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax_cosmo import Cosmology
from jaxpm.growth import growth_factor, growth_rate, dGfa, growth_factor_second, growth_rate_second, dGf2a
from jaxpm.kernels import PGD_kernel, fftk, gradient_kernel, invlaplace_kernel, longrange_kernel
from jaxpm.painting import cic_paint, cic_read
from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
def pm_forces(positions,
              mesh_shape=None,
              delta=None,
              r_split=0,
              paint_absolute_pos=True,
              halo_size=0,
              sharding=None):
    """
    Computes gravitational forces on particles using a PM scheme

    args:
      positions: particle positions when `paint_absolute_pos` is True,
        otherwise displacements (handed to `cic_paint_dx` / `cic_read_dx`)
      mesh_shape: shape of the mesh; when None it is taken from `delta.shape`
      delta: optional density field; a real-valued field is transformed with
        `fft3d`, anything else is used as-is as the Fourier-space field
      r_split: forwarded to `longrange_kernel`
      paint_absolute_pos: selects `cic_paint`/`cic_read` (True) or
        `cic_paint_dx`/`cic_read_dx` (False)
      halo_size, sharding: forwarded to the paint/read functions

    returns:
      the three force components read at the particles, stacked on the last
      axis
    """
    if mesh_shape is None:
        assert (delta is not None),\
            "If mesh_shape is not provided, delta should be provided"
        mesh_shape = delta.shape

    # Pick the paint/read pair matching the particle representation
    if paint_absolute_pos:

        def paint_fn(pos):
            empty_mesh = jnp.zeros(shape=mesh_shape, device=sharding)
            return cic_paint(empty_mesh,
                             pos,
                             halo_size=halo_size,
                             sharding=sharding)

        def read_fn(grid_mesh, pos):
            return cic_read(grid_mesh,
                            pos,
                            halo_size=halo_size,
                            sharding=sharding)
    else:

        def paint_fn(disp):
            return cic_paint_dx(disp, halo_size=halo_size, sharding=sharding)

        def read_fn(grid_mesh, disp):
            return cic_read_dx(grid_mesh,
                               disp,
                               halo_size=halo_size,
                               sharding=sharding)

    # Obtain the density field in Fourier space
    if delta is not None:
        delta_k = fft3d(delta) if jnp.isrealobj(delta) else delta
    else:
        delta_k = fft3d(paint_fn(positions))

    kvec = fftk(delta_k)
    # Computes gravitational potential
    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
        kvec, r_split=r_split)

    # Computes gravitational forces, one mesh per axis, read at the particles
    components = []
    for axis in range(3):
        force_mesh = ifft3d(-gradient_kernel(kvec, axis) * pot_k)
        components.append(read_fn(force_mesh, positions))
    return jnp.stack(components, axis=-1)
def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
def lpt(cosmo,
initial_conditions,
particles=None,
a=0.1,
halo_size=0,
sharding=None,
order=1):
"""
Computes first and second order LPT displacement and momentum,
Computes first and second order LPT displacement and momentum,
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
"""
paint_absolute_pos = particles is not None
if particles is None:
particles = jnp.zeros_like(initial_conditions,
shape=(*initial_conditions.shape, 3))
a = jnp.atleast_1d(a)
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = jnp.fft.rfftn(init_mesh) # TODO: pass the modes directly to save one or two fft?
mesh_shape = init_mesh.shape
init_force = pm_forces(positions, mesh_shape, delta=delta_k)
dx = growth_factor(cosmo, a) * init_force
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
delta_k = fft3d(initial_conditions)
initial_force = pm_forces(particles,
delta=delta_k,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * E * dx
f = a**2 * E * dGfa(cosmo, a) * init_force
f = a**2 * E * dGfa(cosmo, a) * initial_force
if order == 2:
kvec = fftk(mesh_shape)
kvec = fftk(delta_k)
pot_k = delta_k * invlaplace_kernel(kvec)
delta2 = 0
@ -54,47 +96,58 @@ def lpt(cosmo:Cosmology, init_mesh, positions, a, order=1):
# Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
# shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
nabla_i_nabla_i = gradient_kernel(kvec, i)**2
shear_ii = jnp.fft.irfftn(nabla_i_nabla_i * pot_k)
delta2 += shear_ii * shear_acc
shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
delta2 += shear_ii * shear_acc
shear_acc += shear_ii
# for kj in kvec[i+1:]:
for j in range(i+1, 3):
for j in range(i + 1, 3):
# Substract squared strict-up-triangle terms
# delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(kvec, j)
delta2 -= jnp.fft.irfftn(nabla_i_nabla_j * pot_k)**2
nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
kvec, j)
delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2
init_force2 = pm_forces(positions, mesh_shape, delta=jnp.fft.rfftn(delta2))
delta_k2 = fft3d(delta2)
init_force2 = pm_forces(particles,
delta=delta_k2,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding)
# NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
dx2 = 3/7 * growth_factor_second(cosmo, a) * init_force2
dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
f2 = a**2 * E * dGf2a(cosmo, a) * init_force2
dx += dx2
p += p2
f += f2
p += p2
f += f2
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
    """
    Generate initial conditions.

    A random field drawn with `normal_field` is transformed with `fft3d`,
    scaled by the square root of `pk` evaluated on the |k| mesh (including
    the cells-per-volume factor), and transformed back with `ifft3d`.

    args:
      mesh_shape: number of cells along each of the three axes
      box_size: box length along each of the three axes
      pk: callable evaluated on the mesh of wavenumber magnitudes
      seed: forwarded to `normal_field`
      sharding: forwarded to `normal_field`
    """
    # Initialize a random field with one slice on each gpu
    white_noise = normal_field(mesh_shape, seed=seed, sharding=sharding)
    field_k = fft3d(white_noise)

    # Wavenumber magnitude on the mesh, rescaled per axis by cells / box length
    kvec = fftk(field_k)
    squared_terms = ((kk / box_size[i] * mesh_shape[i])**2
                     for i, kk in enumerate(kvec))
    kmesh = sum(squared_terms)**0.5

    n_cells = mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
    volume = box_size[0] * box_size[1] * box_size[2]
    pkmesh = pk(kmesh) * n_cells / volume

    field_k = field_k * (pkmesh)**0.5
    return ifft3d(field_k)
def make_ode_fn(mesh_shape):
def make_ode_fn(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(state, a, cosmo):
"""
@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape):
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
forces = pm_forces(pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
@ -114,20 +171,28 @@ def make_ode_fn(mesh_shape):
return nbody_ode
def get_ode_fn(cosmo:Cosmology, mesh_shape):
def make_diffrax_ode(cosmo,
mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def nbody_ode(a, state, args):
"""
State is an array [position, velocities]
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape) * 1.5 * cosmo.Omega_m
forces = pm_forces(pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
@ -138,51 +203,57 @@ def get_ode_fn(cosmo:Cosmology, mesh_shape):
def pgd_correction(pos, mesh_shape, params):
    """
    improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
    based on https://arxiv.org/abs/1804.00671

    args:
      pos: particle positions [npart, 3]
      mesh_shape: shape of the mesh the particles are painted on
      params: [alpha, kl, ks] pgd parameters

    returns:
      dpos_pgd: the PGD force read at the particle positions (components
        stacked on the last axis), multiplied by alpha
    """
    delta = cic_paint(jnp.zeros(mesh_shape), pos)
    delta_k = fft3d(delta)
    kvec = fftk(delta_k)

    alpha, kl, ks = params

    PGD_range = PGD_kernel(kvec, kl, ks)
    pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range

    # The force is built in Fourier space, so it has to go through the
    # *inverse* transform before being read at the particles (the code this
    # replaces used jnp.fft.irfftn, and pm_forces uses ifft3d the same way).
    forces_pgd = jnp.stack([
        cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
        for i in range(3)
    ],
                           axis=-1)

    dpos_pgd = forces_pgd * alpha

    return dpos_pgd
def make_neural_ode_fn(model, mesh_shape):
def neural_nbody_ode(state, a, cosmo:Cosmology, params):
def neural_nbody_ode(state, a, cosmo: Cosmology, params):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
delta_k = jnp.fft.rfftn(delta)
delta_k = fft3d(delta)
kvec = fftk(delta_k)
# Computes gravitational potential
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)
pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
r_split=0)
# Apply a correction filter
kk = jnp.sqrt(sum((ki/jnp.pi)**2 for ki in kvec))
pot_k = pot_k *(1. + model.apply(params, kk, jnp.atleast_1d(a)))
kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))
# Computes gravitational forces
forces = jnp.stack([cic_read(jnp.fft.irfftn(- gradient_kernel(kvec, i)*pot_k), pos)
for i in range(3)],axis=-1)
forces = jnp.stack([
cic_read(fft3d(-gradient_kernel(kvec, i) * pot_k), pos)
for i in range(3)
],
axis=-1)
forces = forces * 1.5 * cosmo.Omega_m
@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape):
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return neural_nbody_ode

View file

@ -1,47 +1,168 @@
from functools import partial
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre
__all__ = ['power_spectrum']
__all__ = [
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
'cross_correlation_coefficients', 'gaussian_smoothing'
]
def _initialize_pk(shape, boxsize, kmin, dk):
def _initialize_pk(mesh_shape, box_shape, kedges, los):
"""
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
Parameters
----------
mesh_shape : tuple of int
Shape of the mesh grid.
box_shape : tuple of float
Physical dimensions of the box.
kedges : None, int, float, or list
If None, set dk to twice the minimum.
If int, specifies number of edges.
If float, specifies dk.
los : array_like
Line-of-sight vector.
Returns
-------
dig : ndarray
Indices of the bins to which each value in input array belongs.
kcount : ndarray
Count of values in each bin.
kedges : ndarray
Edges of the bins.
mumesh : ndarray
Mu values for the mesh grid.
"""
I = np.eye(len(shape), dtype='int') * -2 + 1
kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
W = np.empty(shape, dtype='f4')
W[...] = 2.0
W[..., 0] = 1.0
W[..., -1] = 1.0
if isinstance(kedges, None | int | float):
if kedges is None:
dk = 2 * np.pi / np.min(
box_shape) * 2 # twice the minimum wavenumber
if isinstance(kedges, int):
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
elif isinstance(kedges, float):
dk = kedges
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
kedges = np.arange(kmin, kmax, dk)
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
kmesh = sum(ki**2 for ki in kvec)**0.5
k = [
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
]
kmag = sum(ki**2 for ki in k)**0.5
dig = np.digitize(kmesh.reshape(-1), kedges)
kcount = np.bincount(dig, minlength=len(kedges) + 1)
xsum = np.zeros(len(kedges) + 1)
Nsum = np.zeros(len(kedges) + 1)
# Central value of each bin
# kavg = (kedges[1:] + kedges[:-1]) / 2
kavg = np.bincount(
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
kavg = kavg[1:-1]
dig = np.digitize(kmag.flat, kedges)
if los is None:
mumesh = 1.
else:
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
return dig, Nsum, xsum, W, k, kedges
return dig, kcount, kavg, mumesh
def power_spectrum(mesh,
                   mesh2=None,
                   box_shape=None,
                   kedges: int | float | list = None,
                   multipoles=0,
                   los=(0., 0., 1.)):
    """
    Compute the auto and cross spectrum of 3D fields, with multipoles.

    Parameters
    ----------
    mesh : array
        Real-space field.
    mesh2 : array, optional
        Second field; when given, the cross spectrum of `mesh` and `mesh2`
        is computed instead of the auto spectrum of `mesh`.
    box_shape : array_like, optional
        Physical dimensions of the box. Defaults to the mesh shape.
    kedges : None, int, float, or list
        Binning specification forwarded to `_initialize_pk`.
    multipoles : int or sequence of int
        Multipole order(s) to compute.
    los : array_like
        Line-of-sight vector (normalized internally); ignored when
        `multipoles` is the scalar 0.

    Returns
    -------
    kavg : ndarray
        Wavenumber associated with each bin, as returned by `_initialize_pk`.
    pk : array
        Binned power; one row per multipole when `multipoles` is a sequence,
        a single 1D array when it is a scalar.
    """
    # Initialize
    mesh_shape = np.array(mesh.shape)
    if box_shape is None:
        box_shape = mesh_shape
    else:
        box_shape = np.asarray(box_shape)

    # Only the scalar monopole request can skip the mu mesh; testing the
    # dimension first keeps array-valued `multipoles` from raising on `== 0`.
    if np.ndim(multipoles) == 0 and multipoles == 0:
        los = None
    else:
        los = np.asarray(los)
        los = los / np.linalg.norm(los)
    poles = np.atleast_1d(multipoles)
    dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
                                               los)
    n_bins = len(kavg) + 2

    # FFTs
    meshk = jnp.fft.fftn(mesh, norm='ortho')
    if mesh2 is None:
        mmk = meshk.real**2 + meshk.imag**2
    else:
        mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()

    # Sum powers
    pk = jnp.empty((len(poles), n_bins))
    for i_ell, ell in enumerate(poles):
        weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
        if mesh2 is None:
            psum = jnp.bincount(dig, weights=weights, length=n_bins)
        else:  # XXX: bincount is really slow with complex numbers
            psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
            psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
            psum = (psum_real**2 + psum_imag**2)**.5
        pk = pk.at[i_ell].set(psum)

    # Normalization and conversion from cell units to [Mpc/h]^3
    pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()

    # Scalar request -> single spectrum, sequence request -> one row per pole
    if np.ndim(multipoles) == 0:
        return kavg, pk[0]
    else:
        return kavg, pk
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    """
    Return the wavenumbers and the square root of the ratio of the power
    spectrum of `mesh1` to that of `mesh0`, both computed by `power_spectrum`
    with the same `box_shape` and `kedges`.
    """
    binning = dict(box_shape=box_shape, kedges=kedges)
    ks, pk0 = power_spectrum(mesh0, **binning)
    ks, pk1 = power_spectrum(mesh1, **binning)
    ratio = pk1 / pk0
    return ks, ratio**.5
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    """
    Return the wavenumbers and the cross spectrum of `mesh0` and `mesh1`
    divided by the square root of the product of their auto spectra, all
    computed by `power_spectrum` with the same `box_shape` and `kedges`.
    """
    binning = dict(box_shape=box_shape, kedges=kedges)
    ks, pk01 = power_spectrum(mesh0, mesh1, **binning)
    ks, pk0 = power_spectrum(mesh0, **binning)
    ks, pk1 = power_spectrum(mesh1, **binning)
    auto_norm = (pk0 * pk1)**.5
    return ks, pk01 / auto_norm
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    """
    Return, in this order: the wavenumbers, the auto spectrum of `mesh0`, the
    auto spectrum of `mesh1`, the square root of their ratio (`mesh1` over
    `mesh0`), and the cross spectrum divided by the square root of the
    product of the auto spectra. All spectra come from `power_spectrum` with
    the same `box_shape` and `kedges`.
    """
    binning = dict(box_shape=box_shape, kedges=kedges)
    ks, pk01 = power_spectrum(mesh0, mesh1, **binning)
    ks, pk0 = power_spectrum(mesh0, **binning)
    ks, pk1 = power_spectrum(mesh1, **binning)
    transfer_fn = (pk1 / pk0)**.5
    coherence_fn = pk01 / (pk0 * pk1)**.5
    return ks, pk0, pk1, transfer_fn, coherence_fn
def cross_correlation_coefficients(field_a,
field_b,
kmin=5,
dk=0.5,
boxsize=False):
"""
Calculate the cross correlation coefficients given two real space field
Args:
field: real valued field
field_a: real valued field
field_b: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
@ -49,20 +170,21 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
Returns:
kbins: the central value of the bins for plotting
power: real valued array of power in each bin
P / norm: normalized cross correlation coefficient between two field a and b
"""
shape = field.shape
shape = field_a.shape
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform
fft_image = jnp.fft.fftn(field)
fft_image_a = jnp.fft.fftn(field_a)
fft_image_b = jnp.fft.fftn(field_b)
#absolute value of fast fourier transform
pk = jnp.real(fft_image * jnp.conj(fft_image))
pk = fft_image_a * jnp.conj(fft_image_b)
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
@ -83,55 +205,6 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
return kbins, P / norm
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
"""
Calculate the cross correlation coefficients given two real space field
Args:
field_a: real valued field
field_b: real valued field
kmin: minimum k-value for binned powerspectra
dk: differential in each kbin
boxsize: length of each boxlength (can be strangly shaped?)
Returns:
kbins: the central value of the bins for plotting
P / norm: normalized cross correlation coefficient between two field a and b
"""
shape = field_a.shape
nx, ny, nz = shape
#initialze values related to powerspectra (mode bins and weights)
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
#fast fourier transform
fft_image_a = jnp.fft.fftn(field_a)
fft_image_b = jnp.fft.fftn(field_b)
#absolute value of fast fourier transform
pk = fft_image_a * jnp.conj(fft_image_b)
#calculating powerspectra
real = jnp.real(pk).reshape([-1])
imag = jnp.imag(pk).reshape([-1])
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
#normalization for powerspectra
norm = np.prod(np.array(shape[:])).astype('float32')**2
#find central values of each bin
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
return kbins, P / norm
def gaussian_smoothing(im, sigma):
"""
im: 2d image