forked from guilhem_lavaux/JaxPM
jaxdecomp proto (#21)
* adding example of distributed solution
* put back old functgion
* update formatting
* add halo exchange and slice pad
* apply formatting
* implement distributed optimized cic_paint
* Use new cic_paint with halo
* Fix seed for distributed normal
* Wrap interpolation function to avoid all gather
* Return normal order frequencies for single GPU
* add example
* format
* add optimised bench script
* times in ms
* add lpt2
* update benchmark and add slurm
* Visualize only final field
* Update scripts/distributed_pm.py (Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>)
* Adjust pencil type for frequencies
* fix painting issue with slabs
* Shared operation in fourrier space now take inverted sharding axis for slabs
* add assert to make pyright happy
* adjust test for hpc-plotter
* add PMWD test
* bench
* format
* added github workflow
* fix formatting from main
* Update for jaxDecomp pure JAX
* revert single halo extent change
* update for latest jaxDecomp
* remove fourrier_space in autoshmap
* make normal_field work with single controller
* format
* make distributed pm work in single controller
* merge bench_pm
* update to leapfrog
* add a strict dependency on jaxdecomp
* global mesh no longer needed
* kernels.py no longer uses global mesh
* quick fix in distributed
* pm.py no longer uses global mesh
* painting.py no longer uses global mesh
* update demo script
* quick fix in kernels
* quick fix in distributed
* update demo
* merge hugos LPT2 code
* format
* Small fix
* format
* remove duplicate get_ode_fn
* update visualizer
* update compensate CIC
* By default check_rep is false for shard_map
* remove experimental distributed code
* update PGDCorrection and neural ode to use new fft3d
* jaxDecomp pfft3d promotes to complex automatically
* remove deprecated stuff
* fix painting issue with read_cic
* use jnp interp instead of jc interp
* delete old slurms
* add notebook examples
* apply formatting
* add distributed zeros
* fix code in LPT2
* jit cic_paint
* update notebooks
* apply formating
* get local shape and zeros can be used by users
* add a user facing function to create uniform particle grid
* use jax interp instead of jax_cosmo
* use float64 for enmeshing
* Allow applying weights with relative cic paint
* Weights can be traced
* remove script folder
* update example notebooks
* delete outdated design file
* add readme for tutorials
* update readme
* fix small error
* forgot particles in multi host
* clarifying why cic_paint_dx is slower
* clarifying the halo size dependence on the box size
* ability to choose snapshots number with MultiHost script
* Adding animation notebook
* Put plotting in package
* Add finite difference laplace kernel + powerspec functions from Hugo (Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>)
* Put plotting utils in package
* By default use absoulute painting with
* update code
* update notebooks
* add tests
* Upgrade setup.py to pyproject
* Format
* format tests
* update test dependencies
* add test workflow
* fix deprecated FftType in jaxpm.kernels
* Add aboucaud comments
* JAX version is 0.4.35 until Diffrax new release
* add numpy explicitly as dependency for tests
* fix install order for tests
* add numpy to be installed
* enforce no build isolation for fastpm
* pip install jaxpm test without build isolation
* bump jaxdecomp version
* revert test workflow
* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
This commit is contained in:
parent a0a79277e5
commit df8602b318
26 changed files with 1871 additions and 434 deletions
198  jaxpm/distributed.py  Normal file
@@ -0,0 +1,198 @@
from typing import Any, Callable, Hashable

Specs = Any
AxisName = Hashable

from functools import partial

import jax
import jax.numpy as jnp
import jaxdecomp
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import AbstractMesh, Mesh
from jax.sharding import PartitionSpec as P


def autoshmap(
        f: Callable,
        gpu_mesh: Mesh | AbstractMesh | None,
        in_specs: Specs,
        out_specs: Specs,
        check_rep: bool = False,
        auto: frozenset[AxisName] = frozenset()) -> Callable:
    """Helper function to wrap the provided function in a shard map if
    the code is being executed in a mesh context."""
    if gpu_mesh is None or gpu_mesh.empty:
        return f
    else:
        return shard_map(f, gpu_mesh, in_specs, out_specs, check_rep, auto)


def fft3d(x):
    return jaxdecomp.pfft3d(x)


def ifft3d(x):
    return jaxdecomp.pifft3d(x).real


def get_halo_size(halo_size, sharding):
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is None or gpu_mesh.empty:
        zero_ext = (0, 0)
        zero_tuple = (0, 0)
        return (zero_tuple, zero_tuple, zero_tuple), zero_ext
    else:
        pdims = gpu_mesh.devices.shape
        halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
        halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)

        halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
        halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
        return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext))


def halo_exchange(x, halo_extents, halo_periods=(True, True)):
    if (halo_extents[0] > 0 or halo_extents[1] > 0):
        return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
    else:
        return x


def slice_unpad_impl(x, pad_width):

    halo_x, _ = pad_width[0]
    halo_y, _ = pad_width[1]
    # Apply corrections along x
    x = x.at[halo_x:halo_x + halo_x // 2].add(x[:halo_x // 2])
    x = x.at[-(halo_x + halo_x // 2):-halo_x].add(x[-halo_x // 2:])
    # Apply corrections along y
    x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
    x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])

    unpad_slice = [slice(None)] * 3
    if halo_x > 0:
        unpad_slice[0] = slice(halo_x, -halo_x)
    if halo_y > 0:
        unpad_slice[1] = slice(halo_y, -halo_y)

    return x[tuple(unpad_slice)]


def slice_pad(x, pad_width, sharding):
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is not None and not (gpu_mesh.empty) and (
            pad_width[0][0] > 0 or pad_width[1][0] > 0):
        assert sharding is not None
        spec = sharding.spec
        return shard_map((partial(jnp.pad, pad_width=pad_width)),
                         mesh=gpu_mesh,
                         in_specs=spec,
                         out_specs=spec)(x)
    else:
        return x


def slice_unpad(x, pad_width, sharding):
    mesh = sharding.mesh if sharding is not None else None
    if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
                                                  or pad_width[1][0] > 0):
        assert sharding is not None
        spec = sharding.spec
        return shard_map(partial(slice_unpad_impl, pad_width=pad_width),
                         mesh=mesh,
                         in_specs=spec,
                         out_specs=spec)(x)
    else:
        return x


def get_local_shape(mesh_shape, sharding=None):
    """ Helper function to get the local size of a mesh given the global size.
    """
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is None or gpu_mesh.empty:
        return mesh_shape
    else:
        pdims = gpu_mesh.devices.shape
        return [
            mesh_shape[0] // pdims[0], mesh_shape[1] // pdims[1],
            *mesh_shape[2:]
        ]


def _axis_names(spec):
    if len(spec) == 1:
        x_axis, = spec
        y_axis = None
        single_axis = True
    elif len(spec) == 2:
        x_axis, y_axis = spec
        if y_axis is None:
            single_axis = True
        elif x_axis is None:
            x_axis = y_axis
            single_axis = True
        else:
            single_axis = False
    else:
        raise ValueError("Only 1 or 2 axis sharding is supported")
    return x_axis, y_axis, single_axis


def uniform_particles(mesh_shape, sharding=None):

    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is not None and not (gpu_mesh.empty):
        local_mesh_shape = get_local_shape(mesh_shape, sharding)
        spec = sharding.spec
        x_axis, y_axis, single_axis = _axis_names(spec)

        def particles():
            x_indx = lax.axis_index(x_axis)
            y_indx = 0 if single_axis else lax.axis_index(y_axis)

            x = jnp.arange(local_mesh_shape[0]) + x_indx * local_mesh_shape[0]
            y = jnp.arange(local_mesh_shape[1]) + y_indx * local_mesh_shape[1]
            z = jnp.arange(local_mesh_shape[2])
            return jnp.stack(jnp.meshgrid(x, y, z, indexing='ij'), axis=-1)

        return shard_map(particles, mesh=gpu_mesh, in_specs=(),
                         out_specs=spec)()
    else:
        return jnp.stack(jnp.meshgrid(*[jnp.arange(s) for s in mesh_shape],
                                      indexing='ij'),
                         axis=-1)


def normal_field(mesh_shape, seed, sharding=None):
    """Generate a Gaussian white-noise field of the requested global shape."""
    gpu_mesh = sharding.mesh if sharding is not None else None
    if gpu_mesh is not None and not (gpu_mesh.empty):
        local_mesh_shape = get_local_shape(mesh_shape, sharding)

        size = jax.device_count()
        # rank = jax.process_index()
        # process_index is multi-host only; splitting the key over all devices
        # makes the code work in both multi-host and single-controller settings.
        keys = jax.random.split(seed, size)
        spec = sharding.spec
        x_axis, y_axis, single_axis = _axis_names(spec)

        def normal(keys, shape, dtype):
            idx = lax.axis_index(x_axis)
            if not single_axis:
                y_index = lax.axis_index(y_axis)
                x_size = lax.psum(1, axis_name=x_axis)
                idx += y_index * x_size

            return jax.random.normal(key=keys[idx], shape=shape, dtype=dtype)

        return shard_map(
            partial(normal, shape=local_mesh_shape, dtype='float32'),
            mesh=gpu_mesh,
            in_specs=P(None),
            out_specs=spec)(keys)  # yapf: disable
    else:
        return jax.random.normal(shape=mesh_shape, key=seed)
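A minimal single-controller sketch of how these helpers compose, assuming jaxDecomp is installed and four GPUs are visible; the 2x2 decomposition, the axis names 'x'/'y', the 64^3 mesh and the halo size of 16 are all illustrative choices, not values mandated by the commit.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import (fft3d, get_halo_size, ifft3d, normal_field,
                               uniform_particles)

# Arrange the 4 local devices as a 2x2 decomposition grid.
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))

mesh_shape = (64, 64, 64)
seed = jax.random.PRNGKey(0)

# White-noise field, one local slab per device, then a distributed FFT round trip.
field = normal_field(mesh_shape, seed=seed, sharding=sharding)
field_k = fft3d(field)
field_back = ifft3d(field_k)

# Halo bookkeeping and a sharded uniform particle grid, as used by the painting code.
pad_width, halo_extents = get_halo_size(16, sharding)
particles = uniform_particles(mesh_shape, sharding=sharding)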
jaxpm/growth.py
@@ -1,6 +1,6 @@
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
from jax_cosmo.scipy.ode import odeint


@@ -587,5 +587,6 @@ def dGf2a(cosmo, a):
    cache = cosmo._workspace['background.growth_factor']
    f2p = cache['h2'] / cache['a'] * cache['g2']
    f2p = interp(np.log(a), np.log(cache['a']), f2p)
    E_a = E(cosmo, a)
    return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
            3 * a**2 * E_a * D2f)
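The swap from jax_cosmo's interpolator to jax.numpy.interp is a drop-in for this 1-D log-space interpolation; a tiny self-contained sketch, where a_cache and f2p_cache are made-up placeholders standing in for the cached growth tables in cosmo._workspace, not real cosmological values:

import jax.numpy as np  # growth.py aliases jax.numpy as np
from jax.numpy import interp

# Placeholder stand-ins for cosmo._workspace['background.growth_factor'] entries.
a_cache = np.logspace(-2, 0, 128)
f2p_cache = a_cache**0.55  # illustrative values only

a = np.atleast_1d(0.5)
f2p = interp(np.log(a), np.log(a_cache), f2p_cache)  # same call pattern as in dGf2a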
jaxpm/kernels.py
@@ -1,30 +1,46 @@
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs

from jaxpm.distributed import autoshmap


def fftk(k_array):
    """
    Generate Fourier transform wave numbers for a given mesh.

    Args:
        k_array: complex field in the (transposed) Fourier layout from which
            the wave numbers are derived.

    Returns:
        list: List of wave number arrays for each dimension in
              the order [kx, ky, kz].
    """
    kx, ky, kz = fftfreq3d(k_array)
    # to the order of dimensions in the transposed FFT
    return kx, ky, kz


def interpolate_power_spectrum(input, k, pk, sharding=None):

    pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

    gpu_mesh = sharding.mesh if sharding is not None else None
    specs = sharding.spec if sharding is not None else P()
    out_specs = P(*get_output_specs(
        FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()

    return autoshmap(pk_fn,
                     gpu_mesh=gpu_mesh,
                     in_specs=out_specs,
                     out_specs=out_specs)(input)


def gradient_kernel(kvec, direction, order=1):
    """
    Computes the gradient kernel in the requested direction

    Parameters
    -----------
    kvec: list

@@ -50,23 +66,30 @@ def gradient_kernel(kvec, direction, order=1):
        return wts


def invlaplace_kernel(kvec, fd=False):
    """
    Compute the inverse Laplace kernel.

    cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)

    Parameters
    -----------
    kvec: list
        List of wave-vectors
    fd: bool
        Finite difference kernel

    Returns
    --------
    wts: array
        Complex kernel values
    """
    if fd:
        kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
    else:
        kk = sum(ki**2 for ki in kvec)
    kk_nozeros = jnp.where(kk == 0, 1, kk)
    return -jnp.where(kk == 0, 0, 1 / kk_nozeros)


def longrange_kernel(kvec, r_split):

@@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
        List of wave-vectors
    r_split: float
        Splitting radius

    Returns
    --------
    wts: array
        Complex kernel values

    TODO: @modichirag add documentation
    """
    if r_split != 0:

@@ -105,13 +126,12 @@ def cic_compensation(kvec):
    -----------
    kvec: list
        List of wave-vectors

    Returns:
    --------
    wts: array
        Complex kernel values
    """
    kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
    wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
    return wts
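A short single-device sketch of the kernels, building the wave-vectors by hand with jnp.fft.fftfreq instead of the distributed fftk/fftfreq3d path (which assumes jaxDecomp's transposed layout); the grid size is arbitrary. It contrasts the spectral and finite-difference inverse Laplacians added in this commit, which agree at low k and differ near the Nyquist frequency.

import jax.numpy as jnp

from jaxpm.kernels import gradient_kernel, invlaplace_kernel

nc = 32
kx = 2 * jnp.pi * jnp.fft.fftfreq(nc).reshape(-1, 1, 1)
ky = 2 * jnp.pi * jnp.fft.fftfreq(nc).reshape(1, -1, 1)
kz = 2 * jnp.pi * jnp.fft.fftfreq(nc).reshape(1, 1, -1)
kvec = [kx, ky, kz]

pot_spectral = invlaplace_kernel(kvec)          # -1/k^2 with the k=0 mode zeroed
pot_fd = invlaplace_kernel(kvec, fd=True)       # finite-difference variant
grad_x = gradient_kernel(kvec, 0)               # gradient kernel along x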
jaxpm/painting.py
@@ -1,15 +1,24 @@
from functools import partial

import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange,
                               ifft3d, slice_pad, slice_unpad)
from jaxpm.kernels import cic_compensation, fftk
from jaxpm.painting_utils import gather, scatter


def _cic_paint_impl(grid_mesh, positions, weight=None):
    """ Paints positions onto mesh
    mesh: [nx, ny, nz]
    displacement field: [nx, ny, nz, 3]
    """

    positions = positions.reshape([-1, 3])
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],

@@ -19,48 +28,106 @@ def cic_paint(mesh, positions, weight=None):
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
    if weight is not None:
        if jnp.isscalar(weight):
            kernel = jnp.multiply(jnp.expand_dims(weight, axis=-1), kernel)
        else:
            kernel = jnp.multiply(weight.reshape(*positions.shape[:-1]),
                                  kernel)

    neighboor_coords = jnp.mod(
        neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
        jnp.array(grid_mesh.shape))

    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                            inserted_window_dims=(0, 1, 2),
                                            scatter_dims_to_operand_dims=(0, 1,
                                                                          2))
    mesh = lax.scatter_add(grid_mesh, neighboor_coords,
                           kernel.reshape([-1, 8]), dnums)
    return mesh


@partial(jax.jit, static_argnums=(3, 4))
def cic_paint(grid_mesh, positions, weight=None, halo_size=0, sharding=None):

    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding)

    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
    grid_mesh = autoshmap(_cic_paint_impl,
                          gpu_mesh=gpu_mesh,
                          in_specs=(spec, spec, P()),
                          out_specs=spec)(grid_mesh, positions, weight)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)

    return grid_mesh


def _cic_read_impl(grid_mesh, positions):
    """ Reads mesh values at the given positions
    mesh: [nx, ny, nz]
    positions: [nx, ny, nz, 3]
    """
    # Save original shape for reshaping output later
    original_shape = positions.shape
    # Reshape positions to a flat list of 3D coordinates
    positions = positions.reshape([-1, 3])
    # Expand dimensions to calculate neighbor coordinates
    positions = jnp.expand_dims(positions, 1)
    # Floor the positions to get the base grid cell for each particle
    floor = jnp.floor(positions)
    # Define connections to calculate all neighbor coordinates
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])

    # Calculate the 8 neighboring coordinates
    neighboor_coords = floor + connection
    # Calculate kernel weights based on distance from each neighboring coordinate
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]

    # Modulo operation to wrap around edges if necessary
    neighboor_coords = jnp.mod(neighboor_coords.astype('int32'),
                               jnp.array(grid_mesh.shape))
    # Retrieve values from grid_mesh at each neighboring coordinate and multiply by kernel
    return (grid_mesh[neighboor_coords[..., 0],
                      neighboor_coords[..., 1],
                      neighboor_coords[..., 2]] * kernel).sum(axis=-1).reshape(original_shape[:-1])  # yapf: disable


@partial(jax.jit, static_argnums=(2, 3))
def cic_read(grid_mesh, positions, halo_size=0, sharding=None):

    original_shape = positions.shape
    positions = positions.reshape((*grid_mesh.shape, 3))

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()

    displacement = autoshmap(_cic_read_impl,
                             gpu_mesh=gpu_mesh,
                             in_specs=(spec, spec),
                             out_specs=spec)(grid_mesh, positions)

    return displacement.reshape(original_shape[:-1])


def cic_paint_2d(mesh, positions, weight):
    """ Paints positions onto a 2d mesh
    mesh: [nx, ny]
    positions: [npart, 2]
    weight: [npart]
    """
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[0, 0], [1., 0], [0., 1], [1., 1]])

@@ -84,17 +151,109 @@ def cic_paint_2d(mesh, positions, weight):
    return mesh


def _cic_paint_dx_impl(displacements, halo_size, weight=1., chunk_size=2**24):

    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = displacements.shape
    particle_mesh = jnp.zeros(original_shape[:-1], dtype='float32')
    if not jnp.isscalar(weight):
        if weight.shape != original_shape[:-1]:
            raise ValueError("Weight shape must match particle shape")
        else:
            weight = weight.flatten()
    # Padding is forced to be zero in a single gpu run

    a, b, c = jnp.meshgrid(jnp.arange(particle_mesh.shape[0]),
                           jnp.arange(particle_mesh.shape[1]),
                           jnp.arange(particle_mesh.shape[2]),
                           indexing='ij')

    particle_mesh = jnp.pad(particle_mesh, halo_size)
    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
    return scatter(pmid.reshape([-1, 3]),
                   displacements.reshape([-1, 3]),
                   particle_mesh,
                   chunk_size=chunk_size,
                   val=weight)


@partial(jax.jit, static_argnums=(1, 2, 4))
def cic_paint_dx(displacements,
                 halo_size=0,
                 sharding=None,
                 weight=1.0,
                 chunk_size=2**24):

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)

    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
    grid_mesh = autoshmap(partial(_cic_paint_dx_impl,
                                  halo_size=halo_size,
                                  weight=weight,
                                  chunk_size=chunk_size),
                          gpu_mesh=gpu_mesh,
                          in_specs=spec,
                          out_specs=spec)(displacements)

    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    grid_mesh = slice_unpad(grid_mesh, halo_size, sharding)
    return grid_mesh


def _cic_read_dx_impl(grid_mesh, disp, halo_size):

    halo_x, _ = halo_size[0]
    halo_y, _ = halo_size[1]

    original_shape = [
        dim - 2 * halo[0] for dim, halo in zip(grid_mesh.shape, halo_size)
    ]
    a, b, c = jnp.meshgrid(jnp.arange(original_shape[0]),
                           jnp.arange(original_shape[1]),
                           jnp.arange(original_shape[2]),
                           indexing='ij')

    pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)

    pmid = pmid.reshape([-1, 3])
    disp = disp.reshape([-1, 3])

    return gather(pmid, disp, grid_mesh).reshape(original_shape)


@partial(jax.jit, static_argnums=(2, 3))
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):

    halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
    grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
    grid_mesh = halo_exchange(grid_mesh,
                              halo_extents=halo_extents,
                              halo_periods=(True, True))
    gpu_mesh = sharding.mesh if isinstance(sharding, NamedSharding) else None
    spec = sharding.spec if isinstance(sharding, NamedSharding) else P()
    displacements = autoshmap(partial(_cic_read_dx_impl, halo_size=halo_size),
                              gpu_mesh=gpu_mesh,
                              in_specs=(spec),
                              out_specs=spec)(grid_mesh, disp)

    return displacements


def compensate_cic(field):
    """
    Compensate for CiC painting
    Args:
      field: input 3D cic-painted field
    Returns:
      compensated_field
    """
    delta_k = fft3d(field)

    kvec = fftk(delta_k)
    delta_k = cic_compensation(kvec) * delta_k
    return ifft3d(delta_k)
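A single-device sketch of the two painting paths, assuming the module layout above; the 32^3 mesh and the 0.1-cell perturbation are illustrative. Absolute painting takes particle coordinates in grid units (one particle per cell here), while the displacement-based variants take only the displacement field and never materialise absolute positions.

import jax
import jax.numpy as jnp

from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read

mesh_shape = (32, 32, 32)
key = jax.random.PRNGKey(0)

# One particle per cell with a small random displacement field.
disp = 0.1 * jax.random.normal(key, (*mesh_shape, 3))
particles = uniform_particles(mesh_shape) + disp

# Absolute painting and trilinear read-back...
field = cic_paint(jnp.zeros(mesh_shape), particles)
values = cic_read(field, particles)

# ...or the relative (displacement) painting introduced in this commit.
field_dx = cic_paint_dx(disp)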
190  jaxpm/painting_utils.py  Normal file
@@ -0,0 +1,190 @@
import jax
import jax.numpy as jnp
from jax.lax import scan


def _chunk_split(ptcl_num, chunk_size, *arrays):
    """Split and reshape particle arrays into chunks and remainders, with the remainders
    preceding the chunks. 0D ones are duplicated as full arrays in the chunks."""
    chunk_size = ptcl_num if chunk_size is None else min(chunk_size, ptcl_num)
    remainder_size = ptcl_num % chunk_size
    chunk_num = ptcl_num // chunk_size

    remainder = None
    chunks = arrays
    if remainder_size:
        remainder = [x[:remainder_size] if x.ndim != 0 else x for x in arrays]
        chunks = [x[remainder_size:] if x.ndim != 0 else x for x in arrays]

    # `scan` triggers errors in scatter and gather without the `full`
    chunks = [
        x.reshape(chunk_num, chunk_size, *x.shape[1:])
        if x.ndim != 0 else jnp.full(chunk_num, x) for x in chunks
    ]

    return remainder, chunks


def enmesh(base_indices, displacements, cell_size, base_shape, offset,
           new_cell_size, new_shape):
    """Multilinear enmeshing."""
    base_indices = jnp.asarray(base_indices)
    displacements = jnp.asarray(displacements)
    with jax.experimental.enable_x64():
        cell_size = jnp.float64(
            cell_size) if new_cell_size is not None else jnp.array(
                cell_size, dtype=displacements.dtype)
        if base_shape is not None:
            base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
        offset = jnp.float64(offset)
        if new_cell_size is not None:
            new_cell_size = jnp.float64(new_cell_size)
        if new_shape is not None:
            new_shape = jnp.array(new_shape, dtype=base_indices.dtype)

    spatial_dim = base_indices.shape[1]
    neighbor_offsets = (
        jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
        jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1

    if new_cell_size is not None:
        particle_positions = base_indices * cell_size + displacements - offset
        particle_positions = particle_positions[:, jnp.
                                                newaxis]  # insert neighbor axis
        new_indices = particle_positions + neighbor_offsets * new_cell_size  # multilinear

        if base_shape is not None:
            grid_length = base_shape * cell_size
            new_indices %= grid_length

        new_indices //= new_cell_size
        new_displacements = particle_positions - new_indices * new_cell_size

        if base_shape is not None:
            new_displacements -= jnp.rint(
                new_displacements / grid_length
            ) * grid_length  # also abs(new_displacements) < new_cell_size is expected

        new_indices = new_indices.astype(base_indices.dtype)
        new_displacements = new_displacements.astype(displacements.dtype)
        new_cell_size = new_cell_size.astype(displacements.dtype)

        new_displacements /= new_cell_size
    else:
        offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
        base_indices -= offset_indices.astype(base_indices.dtype)
        displacements -= offset_displacements.astype(displacements.dtype)

        # insert neighbor axis
        base_indices = base_indices[:, jnp.newaxis]
        displacements = displacements[:, jnp.newaxis]

        # multilinear
        displacements /= cell_size
        new_indices = jnp.floor(displacements).astype(base_indices.dtype)
        new_indices += neighbor_offsets
        new_displacements = displacements - new_indices
        new_indices += base_indices

        if base_shape is not None:
            new_indices %= base_shape

    weights = 1 - jnp.abs(new_displacements)

    if base_shape is None and new_shape is not None:  # all new_indices >= 0 if base_shape is not None
        new_indices = jnp.where(new_indices < 0, new_shape, new_indices)

    weights = weights.prod(axis=-1)

    return new_indices, weights


def _scatter_chunk(carry, chunk):
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk
    spatial_ndim = pmid.shape[1]
    spatial_shape = mesh.shape

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)
    # scatter
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    mesh = mesh.at[ind].add(jnp.multiply(jnp.expand_dims(val, axis=-1), frac))
    carry = mesh, offset, cell_size, mesh_shape
    return carry, None


def scatter(pmid,
            disp,
            mesh,
            chunk_size=2**24,
            val=1.,
            offset=0,
            cell_size=1.):
    ptcl_num, spatial_ndim = pmid.shape
    val = jnp.asarray(val)
    mesh = jnp.asarray(mesh)
    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)
    carry = mesh, offset, cell_size, mesh.shape
    if remainder is not None:
        carry = _scatter_chunk(carry, remainder)[0]
    carry = scan(_scatter_chunk, carry, chunks)[0]
    mesh = carry[0]
    return mesh


def _chunk_cat(remainder_array, chunked_array):
    """Reshape and concatenate one remainder and one chunked particle arrays."""
    array = chunked_array.reshape(-1, *chunked_array.shape[2:])

    if remainder_array is not None:
        array = jnp.concatenate((remainder_array, array), axis=0)

    return array


def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
    ptcl_num, spatial_ndim = pmid.shape

    mesh = jnp.asarray(mesh)

    val = jnp.asarray(val)

    if mesh.shape[spatial_ndim:] != val.shape[1:]:
        raise ValueError('channel shape mismatch: '
                         f'{mesh.shape[spatial_ndim:]} != {val.shape[1:]}')

    remainder, chunks = _chunk_split(ptcl_num, chunk_size, pmid, disp, val)

    carry = mesh, offset, cell_size, mesh.shape
    val_0 = None
    if remainder is not None:
        val_0 = _gather_chunk(carry, remainder)[1]
    val = scan(_gather_chunk, carry, chunks)[1]

    val = _chunk_cat(val_0, val)

    return val


def _gather_chunk(carry, chunk):
    mesh, offset, cell_size, mesh_shape = carry
    pmid, disp, val = chunk

    spatial_ndim = pmid.shape[1]

    spatial_shape = mesh.shape[:spatial_ndim]
    chan_ndim = mesh.ndim - spatial_ndim
    chan_axis = tuple(range(-chan_ndim, 0))

    # multilinear mesh indices and fractions
    ind, frac = enmesh(pmid, disp, cell_size, mesh_shape, offset, cell_size,
                       spatial_shape)

    # gather
    ind = tuple(ind[..., i] for i in range(spatial_ndim))
    frac = jnp.expand_dims(frac, chan_axis)
    val += (mesh.at[ind].get(mode='drop', fill_value=0) * frac).sum(axis=1)

    return carry, val
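A toy sketch of the chunked scatter/gather pair on a tiny mesh, assuming the module above; the particle indices, displacements and 8^3 mesh are arbitrary. scatter deposits unit mass with multilinear weights, gather reads the same field back at the particle locations.

import jax.numpy as jnp

from jaxpm.painting_utils import gather, scatter

# Four particles: integer cell indices plus sub-cell displacements.
pmid = jnp.array([[0, 0, 0], [1, 2, 3], [4, 4, 4], [7, 7, 7]])
disp = jnp.array([[0.25, 0.0, 0.0],
                  [0.0, 0.5, 0.0],
                  [0.0, 0.0, 0.75],
                  [0.5, 0.5, 0.5]], dtype='float32')

mesh = jnp.zeros((8, 8, 8), dtype='float32')
mesh = scatter(pmid, disp, mesh)      # mass deposited with multilinear weights
values = gather(pmid, disp, mesh)     # field values interpolated back to the particles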
129  jaxpm/plotting.py  Normal file
@@ -0,0 +1,129 @@
import matplotlib.pyplot as plt
import numpy as np


def plot_fields(fields_dict, sum_over=None):
    """
    Plots sum projections of 3D fields along different axes,
    slicing only the first `sum_over` elements along each axis.

    Args:
    - fields: list of 3D arrays representing fields to plot
    - names: list of names for each field, used in titles
    - sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
    nb_rows = len(fields_dict)
    nb_cols = 3
    fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))

    def plot_subplots(proj_axis, field, row, title):
        slicing = [slice(None)] * field.ndim
        slicing[proj_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # Sum projection over the specified axis and plot
        axes[row, proj_axis].imshow(
            field[slicing].sum(axis=proj_axis) + 1,
            cmap='magma',
            extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
        axes[row, proj_axis].set_xlabel('Mpc/h')
        axes[row, proj_axis].set_ylabel('Mpc/h')
        axes[row, proj_axis].set_title(title)

    # Plot each field across the three axes
    for i, (name, field) in enumerate(fields_dict.items()):
        for proj_axis in range(3):
            plot_subplots(proj_axis, field, i,
                          f"{name} projection {proj_axis}")

    plt.tight_layout()
    plt.show()


def plot_fields_single_projection(fields_dict,
                                  sum_over=None,
                                  project_axis=0,
                                  vmin=None,
                                  vmax=None,
                                  colorbar=False):
    """
    Plots a single projection (along axis 0) of 3D fields in a grid,
    summing over the first `sum_over` elements along the 0-axis, with 4 images per row.

    Args:
    - fields_dict: dictionary where keys are field names and values are 3D arrays
    - sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
    nb_fields = len(fields_dict)
    nb_cols = 4  # Set number of images per row
    nb_rows = (nb_fields + nb_cols - 1) // nb_cols  # Calculate required rows

    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(5 * nb_cols, 5 * nb_rows))
    axes = np.atleast_2d(axes)  # Ensure axes is always a 2D array

    for i, (name, field) in enumerate(fields_dict.items()):
        row, col = divmod(i, nb_cols)

        # Define the slice for the 0-axis projection
        slicing = [slice(None)] * field.ndim
        slicing[project_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # Sum projection over axis 0 and plot
        a = axes[row,
                 col].imshow(field[slicing].sum(axis=project_axis) + 1,
                             cmap='magma',
                             extent=[0, field.shape[1], 0, field.shape[2]],
                             vmin=vmin,
                             vmax=vmax)
        axes[row, col].set_xlabel('Mpc/h')
        axes[row, col].set_ylabel('Mpc/h')
        axes[row, col].set_title(f"{name} projection 0")
        if colorbar:
            fig.colorbar(a, ax=axes[row, col], shrink=0.7)

    # Remove any empty subplots
    for j in range(i + 1, nb_rows * nb_cols):
        fig.delaxes(axes.flatten()[j])

    plt.tight_layout()
    plt.show()


def stack_slices(array):
    """
    Stacks 2D slices of an array into a single array based on provided partition dimensions.

    Args:
    - array_slices: a 2D list of array slices (list of lists format) where
      array_slices[i][j] is the slice located at row i, column j in the grid.
    - pdims: a tuple representing the grid dimensions (rows, columns).

    Returns:
    - A single array constructed by stacking the slices.
    """
    # Initialize an empty list to store the vertically stacked rows
    pdims = array.sharding.mesh.devices.shape

    field_slices = []

    # Iterate over rows in pdims[0]
    for i in range(pdims[0]):
        row_slices = []

        # Iterate over columns in pdims[1]
        for j in range(pdims[1]):
            slice_index = i * pdims[0] + j
            row_slices.append(array.addressable_data(slice_index))
        # Stack the current row of slices vertically
        stacked_row = np.hstack(row_slices)
        field_slices.append(stacked_row)

    # Stack all rows horizontally to form the full array
    full_array = np.vstack(field_slices)

    return full_array
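A quick usage sketch for the new plotting helpers; the two log-normal-ish toy fields below are placeholders for actual cic_paint outputs, and the 64^3 size and sum_over=8 are arbitrary.

import jax
import jax.numpy as jnp

from jaxpm.plotting import plot_fields_single_projection

key = jax.random.PRNGKey(0)
fields = {
    "lpt": jnp.exp(jax.random.normal(key, (64, 64, 64))),
    "pm": jnp.exp(jax.random.normal(jax.random.split(key)[0], (64, 64, 64))),
}
plot_fields_single_projection(fields, sum_over=8, colorbar=True)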
206  jaxpm/pm.py
@@ -1,50 +1,92 @@
import jax
import jax.numpy as jnp
import jax_cosmo as jc
from jax_cosmo import Cosmology

from jaxpm.distributed import fft3d, ifft3d, normal_field
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
                          growth_rate, growth_rate_second)
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
                           invlaplace_kernel, longrange_kernel)
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx


def pm_forces(positions,
              mesh_shape=None,
              delta=None,
              r_split=0,
              paint_absolute_pos=True,
              halo_size=0,
              sharding=None):
    """
    Computes gravitational forces on particles using a PM scheme
    """
    if mesh_shape is None:
        assert (delta is not None),\
            "If mesh_shape is not provided, delta should be provided"
        mesh_shape = delta.shape

    if paint_absolute_pos:
        paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
                                                   device=sharding),
                                         pos,
                                         halo_size=halo_size,
                                         sharding=sharding)
        read_fn = lambda grid_mesh, pos: cic_read(
            grid_mesh, pos, halo_size=halo_size, sharding=sharding)
    else:
        paint_fn = lambda disp: cic_paint_dx(
            disp, halo_size=halo_size, sharding=sharding)
        read_fn = lambda grid_mesh, disp: cic_read_dx(
            grid_mesh, disp, halo_size=halo_size, sharding=sharding)

    if delta is None:
        field = paint_fn(positions)
        delta_k = fft3d(field)
    elif jnp.isrealobj(delta):
        delta_k = fft3d(delta)
    else:
        delta_k = delta

    kvec = fftk(delta_k)
    # Computes gravitational potential
    pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(
        kvec, r_split=r_split)
    # Computes gravitational forces
    forces = jnp.stack([
        read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), positions
                ) for i in range(3)], axis=-1)  # yapf: disable

    return forces


def lpt(cosmo,
        initial_conditions,
        particles=None,
        a=0.1,
        halo_size=0,
        sharding=None,
        order=1):
    """
    Computes first and second order LPT displacement and momentum,
    e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
    """
    paint_absolute_pos = particles is not None
    if particles is None:
        particles = jnp.zeros_like(initial_conditions,
                                   shape=(*initial_conditions.shape, 3))

    a = jnp.atleast_1d(a)
    E = jnp.sqrt(jc.background.Esqr(cosmo, a))
    delta_k = fft3d(initial_conditions)
    initial_force = pm_forces(particles,
                              delta=delta_k,
                              paint_absolute_pos=paint_absolute_pos,
                              halo_size=halo_size,
                              sharding=sharding)
    dx = growth_factor(cosmo, a) * initial_force
    p = a**2 * growth_rate(cosmo, a) * E * dx
    f = a**2 * E * dGfa(cosmo, a) * initial_force
    if order == 2:
        kvec = fftk(delta_k)
        pot_k = delta_k * invlaplace_kernel(kvec)

        delta2 = 0

@@ -54,47 +96,58 @@ def lpt(cosmo, initial_conditions, particles=None, a=0.1, halo_size=0, sharding=None, order=1):
            # Add products of diagonal terms = 0 + s11*s00 + s22*(s11+s00)...
            # shear_ii = jnp.fft.irfftn(- ki**2 * pot_k)
            nabla_i_nabla_i = gradient_kernel(kvec, i)**2
            shear_ii = ifft3d(nabla_i_nabla_i * pot_k)
            delta2 += shear_ii * shear_acc
            shear_acc += shear_ii

            # for kj in kvec[i+1:]:
            for j in range(i + 1, 3):
                # Substract squared strict-up-triangle terms
                # delta2 -= jnp.fft.irfftn(- ki * kj * pot_k)**2
                nabla_i_nabla_j = gradient_kernel(kvec, i) * gradient_kernel(
                    kvec, j)
                delta2 -= ifft3d(nabla_i_nabla_j * pot_k)**2

        delta_k2 = fft3d(delta2)
        init_force2 = pm_forces(particles,
                                delta=delta_k2,
                                paint_absolute_pos=paint_absolute_pos,
                                halo_size=halo_size,
                                sharding=sharding)
        # NOTE: growth_factor_second is renormalized: - D2 = 3/7 * growth_factor_second
        dx2 = 3 / 7 * growth_factor_second(cosmo, a) * init_force2
        p2 = a**2 * growth_rate_second(cosmo, a) * E * dx2
        f2 = a**2 * E * dGf2a(cosmo, a) * init_force2

        dx += dx2
        p += p2
        f += f2

    return dx, p, f


def linear_field(mesh_shape, box_size, pk, seed, sharding=None):
    """
    Generate initial conditions.
    """
    # Initialize a random field with one slice on each gpu
    field = normal_field(mesh_shape, seed=seed, sharding=sharding)
    field = fft3d(field)
    kvec = fftk(field)
    kmesh = sum((kk / box_size[i] * mesh_shape[i])**2
                for i, kk in enumerate(kvec))**0.5
    pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (
        box_size[0] * box_size[1] * box_size[2])

    field = field * (pkmesh)**0.5
    field = ifft3d(field)
    return field


def make_ode_fn(mesh_shape,
                paint_absolute_pos=True,
                halo_size=0,
                sharding=None):

    def nbody_ode(state, a, cosmo):
        """

@@ -102,7 +155,11 @@ def make_ode_fn(mesh_shape):
        """
        pos, vel = state

        forces = pm_forces(pos,
                           mesh_shape=mesh_shape,
                           paint_absolute_pos=paint_absolute_pos,
                           halo_size=halo_size,
                           sharding=sharding) * 1.5 * cosmo.Omega_m

        # Computes the update of position (drift)
        dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

@@ -114,20 +171,28 @@ def make_ode_fn(mesh_shape):
    return nbody_ode


def make_diffrax_ode(cosmo,
                     mesh_shape,
                     paint_absolute_pos=True,
                     halo_size=0,
                     sharding=None):

    def nbody_ode(a, state, args):
        """
        State is an array [position, velocities]

        Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
        """
        pos, vel = state

        forces = pm_forces(pos,
                           mesh_shape=mesh_shape,
                           paint_absolute_pos=paint_absolute_pos,
                           halo_size=halo_size,
                           sharding=sharding) * 1.5 * cosmo.Omega_m

        # Computes the update of position (drift)
        dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

        # Computes the update of velocity (kick)
        dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

@@ -138,51 +203,57 @@ def make_diffrax_ode(cosmo, mesh_shape):
def pgd_correction(pos, mesh_shape, params):
    """
    improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
    based on https://arxiv.org/abs/1804.00671

    args:
      pos: particle positions [npart, 3]
      params: [alpha, kl, ks] pgd parameters
    """
    delta = cic_paint(jnp.zeros(mesh_shape), pos)
    delta_k = fft3d(delta)
    kvec = fftk(delta_k)
    alpha, kl, ks = params
    PGD_range = PGD_kernel(kvec, kl, ks)

    pot_k_pgd = (delta_k * invlaplace_kernel(kvec)) * PGD_range

    forces_pgd = jnp.stack([
        cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k_pgd), pos)
        for i in range(3)
    ],
                           axis=-1)

    dpos_pgd = forces_pgd * alpha

    return dpos_pgd


def make_neural_ode_fn(model, mesh_shape):

    def neural_nbody_ode(state, a, cosmo: Cosmology, params):
        """
        state is a tuple (position, velocities)
        """
        pos, vel = state

        delta = cic_paint(jnp.zeros(mesh_shape), pos)

        delta_k = fft3d(delta)
        kvec = fftk(delta_k)

        # Computes gravitational potential
        pot_k = delta_k * invlaplace_kernel(kvec) * longrange_kernel(kvec,
                                                                     r_split=0)

        # Apply a correction filter
        kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec))
        pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))

        # Computes gravitational forces
        forces = jnp.stack([
            cic_read(ifft3d(-gradient_kernel(kvec, i) * pot_k), pos)
            for i in range(3)
        ],
                           axis=-1)

        forces = forces * 1.5 * cosmo.Omega_m

@@ -193,4 +264,5 @@ def make_neural_ode_fn(model, mesh_shape):
        dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

        return dpos, dvel

    return neural_nbody_ode
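A minimal end-to-end LPT sketch built on the functions above, assuming jaxDecomp is installed (fft3d works on a single device as well); the 64^3 mesh, box size, scale factor and the tabulated jax_cosmo power spectrum are illustrative choices.

import jax
import jax.numpy as jnp
import jax_cosmo as jc

from jaxpm.painting import cic_paint_dx
from jaxpm.pm import linear_field, lpt

mesh_shape = (64, 64, 64)
box_size = (64., 64., 64.)  # Mpc/h, illustrative
cosmo = jc.Planck15()

# Tabulate a linear power spectrum once and interpolate it inside linear_field.
k = jnp.logspace(-4, 1, 128)
pk = jc.power.linear_matter_power(cosmo, k)
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)

initial_conditions = linear_field(mesh_shape, box_size, pk_fn,
                                  seed=jax.random.PRNGKey(0))

# First-order LPT displacements at a=0.1; with particles=None the relative
# (displacement) painting path is used.
dx, p, f = lpt(cosmo, initial_conditions, a=0.1, order=1)
density = cic_paint_dx(dx)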
227  jaxpm/utils.py
@@ -1,47 +1,168 @@
from functools import partial

import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm
from scipy.special import legendre

__all__ = [
    'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
    'cross_correlation_coefficients', 'gaussian_smoothing'
]


def _initialize_pk(mesh_shape, box_shape, kedges, los):
    """
    Helper function to initialize various (fixed) values for powerspectra... not differentiable!

    Parameters
    ----------
    mesh_shape : tuple of int
        Shape of the mesh grid.
    box_shape : tuple of float
        Physical dimensions of the box.
    kedges : None, int, float, or list
        If None, set dk to twice the minimum.
        If int, specifies number of edges.
        If float, specifies dk.
    los : array_like
        Line-of-sight vector.

    Returns
    -------
    dig : ndarray
        Indices of the bins to which each value in input array belongs.
    kcount : ndarray
        Count of values in each bin.
    kedges : ndarray
        Edges of the bins.
    mumesh : ndarray
        Mu values for the mesh grid.
    """
    kmax = np.pi * np.min(mesh_shape / box_shape)  # = knyquist

    if isinstance(kedges, None | int | float):
        if kedges is None:
            dk = 2 * np.pi / np.min(
                box_shape) * 2  # twice the minimum wavenumber
        if isinstance(kedges, int):
            dk = kmax / (kedges + 1)  # final number of bins will be kedges-1
        elif isinstance(kedges, float):
            dk = kedges
        kedges = np.arange(dk, kmax, dk) + dk / 2  # from dk/2 to kmax-dk/2

    kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
    kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
            for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
    kmesh = sum(ki**2 for ki in kvec)**0.5

    dig = np.digitize(kmesh.reshape(-1), kedges)
    kcount = np.bincount(dig, minlength=len(kedges) + 1)

    # Central value of each bin
    # kavg = (kedges[1:] + kedges[:-1]) / 2
    kavg = np.bincount(
        dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
    kavg = kavg[1:-1]

    if los is None:
        mumesh = 1.
    else:
        mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
        kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
        mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)

    return dig, kcount, kavg, mumesh


def power_spectrum(mesh,
                   mesh2=None,
                   box_shape=None,
                   kedges: int | float | list = None,
                   multipoles=0,
                   los=[0., 0., 1.]):
    """
    Compute the auto and cross spectrum of 3D fields, with multipoles.
    """
    # Initialize
    mesh_shape = np.array(mesh.shape)
    if box_shape is None:
        box_shape = mesh_shape
    else:
        box_shape = np.asarray(box_shape)

    if multipoles == 0:
        los = None
    else:
        los = np.asarray(los)
        los = los / np.linalg.norm(los)
    poles = np.atleast_1d(multipoles)
    dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
                                               los)
    n_bins = len(kavg) + 2

    # FFTs
    meshk = jnp.fft.fftn(mesh, norm='ortho')
    if mesh2 is None:
        mmk = meshk.real**2 + meshk.imag**2
    else:
        mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()

    # Sum powers
    pk = jnp.empty((len(poles), n_bins))
    for i_ell, ell in enumerate(poles):
        weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
        if mesh2 is None:
            psum = jnp.bincount(dig, weights=weights, length=n_bins)
        else:  # XXX: bincount is really slow with complex numbers
            psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
            psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
            psum = (psum_real**2 + psum_imag**2)**.5
        pk = pk.at[i_ell].set(psum)

    # Normalization and conversion from cell units to [Mpc/h]^3
    pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()

    # pk = jnp.concatenate([kavg[None], pk])
    if np.ndim(multipoles) == 0:
        return kavg, pk[0]
    else:
        return kavg, pk


def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, (pk1 / pk0)**.5


def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk01 = pk_fn(mesh0, mesh1)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, pk01 / (pk0 * pk1)**.5


def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
    pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
    ks, pk01 = pk_fn(mesh0, mesh1)
    ks, pk0 = pk_fn(mesh0)
    ks, pk1 = pk_fn(mesh1)
    return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5


def cross_correlation_coefficients(field_a,
                                   field_b,
                                   kmin=5,
                                   dk=0.5,
                                   boxsize=False):
    """
    Calculate the cross correlation coefficients given two real space field

    Args:

        field_a: real valued field
        field_b: real valued field
        kmin: minimum k-value for binned powerspectra
        dk: differential in each kbin
        boxsize: length of each boxlength (can be strangly shaped?)

@@ -49,20 +170,21 @@ def cross_correlation_coefficients(field_a, field_b, kmin=5, dk=0.5, boxsize=False):
    Returns:

        kbins: the central value of the bins for plotting
        P / norm: normalized cross correlation coefficient between two field a and b

    """
    shape = field_a.shape
    nx, ny, nz = shape

    #initialze values related to powerspectra (mode bins and weights)
    dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)

    #fast fourier transform
    fft_image_a = jnp.fft.fftn(field_a)
    fft_image_b = jnp.fft.fftn(field_b)

    #absolute value of fast fourier transform
    pk = fft_image_a * jnp.conj(fft_image_b)

    #calculating powerspectra
    real = jnp.real(pk).reshape([-1])

@@ -83,55 +205,6 @@ def cross_correlation_coefficients(field_a, field_b, kmin=5, dk=0.5, boxsize=False):
    return kbins, P / norm


def gaussian_smoothing(im, sigma):
    """
    im: 2d image
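A small sketch of the new spectrum utilities on toy Gaussian fields; the mesh and box sizes are arbitrary and the two fields are correlated by construction only to make the transfer and coherence outputs non-trivial.

import jax
import jax.numpy as jnp

from jaxpm.utils import pktranscoh, power_spectrum

mesh_shape = (64, 64, 64)
box_shape = (200., 200., 200.)  # Mpc/h, illustrative

key = jax.random.PRNGKey(0)
field_a = jax.random.normal(key, mesh_shape)
field_b = field_a + 0.1 * jax.random.normal(jax.random.split(key)[0], mesh_shape)

# Monopole auto-spectrum; kedges=None picks a default bin width.
k, pk_a = power_spectrum(field_a, box_shape=box_shape)

# Auto/cross spectra, transfer function and coherence in one call.
k, pk0, pk1, tf, coh = pktranscoh(field_a, field_b, box_shape=box_shape)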