mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-14 03:51:11 +00:00
jaxdecomp proto (#21)
* adding example of distributed solution * put back old functgion * update formatting * add halo exchange and slice pad * apply formatting * implement distributed optimized cic_paint * Use new cic_paint with halo * Fix seed for distributed normal * Wrap interpolation function to avoid all gather * Return normal order frequencies for single GPU * add example * format * add optimised bench script * times in ms * add lpt2 * update benchmark and add slurm * Visualize only final field * Update scripts/distributed_pm.py Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> * Adjust pencil type for frequencies * fix painting issue with slabs * Shared operation in fourrier space now take inverted sharding axis for slabs * add assert to make pyright happy * adjust test for hpc-plotter * add PMWD test * bench * format * added github workflow * fix formatting from main * Update for jaxDecomp pure JAX * revert single halo extent change * update for latest jaxDecomp * remove fourrier_space in autoshmap * make normal_field work with single controller * format * make distributed pm work in single controller * merge bench_pm * update to leapfrog * add a strict dependency on jaxdecomp * global mesh no longer needed * kernels.py no longer uses global mesh * quick fix in distributed * pm.py no longer uses global mesh * painting.py no longer uses global mesh * update demo script * quick fix in kernels * quick fix in distributed * update demo * merge hugos LPT2 code * format * Small fix * format * remove duplicate get_ode_fn * update visualizer * update compensate CIC * By default check_rep is false for shard_map * remove experimental distributed code * update PGDCorrection and neural ode to use new fft3d * jaxDecomp pfft3d promotes to complex automatically * remove deprecated stuff * fix painting issue with read_cic * use jnp interp instead of jc interp * delete old slurms * add notebook examples * apply formatting * add distributed zeros * fix code in LPT2 * jit cic_paint * update notebooks * apply formating * get local shape and zeros can be used by users * add a user facing function to create uniform particle grid * use jax interp instead of jax_cosmo * use float64 for enmeshing * Allow applying weights with relative cic paint * Weights can be traced * remove script folder * update example notebooks * delete outdated design file * add readme for tutorials * update readme * fix small error * forgot particles in multi host * clarifying why cic_paint_dx is slower * clarifying the halo size dependence on the box size * ability to choose snapshots number with MultiHost script * Adding animation notebook * Put plotting in package * Add finite difference laplace kernel + powerspec functions from Hugo Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> * Put plotting utils in package * By default use absoulute painting with * update code * update notebooks * add tests * Upgrade setup.py to pyproject * Format * format tests * update test dependencies * add test workflow * fix deprecated FftType in jaxpm.kernels * Add aboucaud comments * JAX version is 0.4.35 until Diffrax new release * add numpy explicitly as dependency for tests * fix install order for tests * add numpy to be installed * enforce no build isolation for fastpm * pip install jaxpm test without build isolation * bump jaxdecomp version * revert test workflow * remove outdated tests --------- Co-authored-by: EiffL <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr> Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
This commit is contained in:
parent
a0a79277e5
commit
df8602b318
26 changed files with 1871 additions and 434 deletions
227
jaxpm/utils.py
227
jaxpm/utils.py
|
@ -1,47 +1,168 @@
|
|||
from functools import partial
|
||||
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.scipy.stats import norm
|
||||
from scipy.special import legendre
|
||||
|
||||
__all__ = ['power_spectrum']
|
||||
__all__ = [
|
||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||
]
|
||||
|
||||
|
||||
def _initialize_pk(shape, boxsize, kmin, dk):
|
||||
def _initialize_pk(mesh_shape, box_shape, kedges, los):
|
||||
"""
|
||||
Helper function to initialize various (fixed) values for powerspectra... not differentiable!
|
||||
Parameters
|
||||
----------
|
||||
mesh_shape : tuple of int
|
||||
Shape of the mesh grid.
|
||||
box_shape : tuple of float
|
||||
Physical dimensions of the box.
|
||||
kedges : None, int, float, or list
|
||||
If None, set dk to twice the minimum.
|
||||
If int, specifies number of edges.
|
||||
If float, specifies dk.
|
||||
los : array_like
|
||||
Line-of-sight vector.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dig : ndarray
|
||||
Indices of the bins to which each value in input array belongs.
|
||||
kcount : ndarray
|
||||
Count of values in each bin.
|
||||
kedges : ndarray
|
||||
Edges of the bins.
|
||||
mumesh : ndarray
|
||||
Mu values for the mesh grid.
|
||||
"""
|
||||
I = np.eye(len(shape), dtype='int') * -2 + 1
|
||||
kmax = np.pi * np.min(mesh_shape / box_shape) # = knyquist
|
||||
|
||||
W = np.empty(shape, dtype='f4')
|
||||
W[...] = 2.0
|
||||
W[..., 0] = 1.0
|
||||
W[..., -1] = 1.0
|
||||
if isinstance(kedges, None | int | float):
|
||||
if kedges is None:
|
||||
dk = 2 * np.pi / np.min(
|
||||
box_shape) * 2 # twice the minimum wavenumber
|
||||
if isinstance(kedges, int):
|
||||
dk = kmax / (kedges + 1) # final number of bins will be kedges-1
|
||||
elif isinstance(kedges, float):
|
||||
dk = kedges
|
||||
kedges = np.arange(dk, kmax, dk) + dk / 2 # from dk/2 to kmax-dk/2
|
||||
|
||||
kmax = np.pi * np.min(np.array(shape)) / np.max(np.array(boxsize)) + dk / 2
|
||||
kedges = np.arange(kmin, kmax, dk)
|
||||
kshapes = np.eye(len(mesh_shape), dtype=np.int32) * -2 + 1
|
||||
kvec = [(2 * np.pi * m / l) * np.fft.fftfreq(m).reshape(kshape)
|
||||
for m, l, kshape in zip(mesh_shape, box_shape, kshapes)]
|
||||
kmesh = sum(ki**2 for ki in kvec)**0.5
|
||||
|
||||
k = [
|
||||
np.fft.fftfreq(N, 1. / (N * 2 * np.pi / L))[:pkshape].reshape(kshape)
|
||||
for N, L, kshape, pkshape in zip(shape, boxsize, I, shape)
|
||||
]
|
||||
kmag = sum(ki**2 for ki in k)**0.5
|
||||
dig = np.digitize(kmesh.reshape(-1), kedges)
|
||||
kcount = np.bincount(dig, minlength=len(kedges) + 1)
|
||||
|
||||
xsum = np.zeros(len(kedges) + 1)
|
||||
Nsum = np.zeros(len(kedges) + 1)
|
||||
# Central value of each bin
|
||||
# kavg = (kedges[1:] + kedges[:-1]) / 2
|
||||
kavg = np.bincount(
|
||||
dig, weights=kmesh.reshape(-1), minlength=len(kedges) + 1) / kcount
|
||||
kavg = kavg[1:-1]
|
||||
|
||||
dig = np.digitize(kmag.flat, kedges)
|
||||
if los is None:
|
||||
mumesh = 1.
|
||||
else:
|
||||
mumesh = sum(ki * losi for ki, losi in zip(kvec, los))
|
||||
kmesh_nozeros = np.where(kmesh == 0, 1, kmesh)
|
||||
mumesh = np.where(kmesh == 0, 0, mumesh / kmesh_nozeros)
|
||||
|
||||
xsum.flat += np.bincount(dig, weights=(W * kmag).flat, minlength=xsum.size)
|
||||
Nsum.flat += np.bincount(dig, weights=W.flat, minlength=xsum.size)
|
||||
return dig, Nsum, xsum, W, k, kedges
|
||||
return dig, kcount, kavg, mumesh
|
||||
|
||||
|
||||
def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
||||
def power_spectrum(mesh,
|
||||
mesh2=None,
|
||||
box_shape=None,
|
||||
kedges: int | float | list = None,
|
||||
multipoles=0,
|
||||
los=[0., 0., 1.]):
|
||||
"""
|
||||
Calculate the powerspectra given real space field
|
||||
Compute the auto and cross spectrum of 3D fields, with multipoles.
|
||||
"""
|
||||
# Initialize
|
||||
mesh_shape = np.array(mesh.shape)
|
||||
if box_shape is None:
|
||||
box_shape = mesh_shape
|
||||
else:
|
||||
box_shape = np.asarray(box_shape)
|
||||
|
||||
if multipoles == 0:
|
||||
los = None
|
||||
else:
|
||||
los = np.asarray(los)
|
||||
los = los / np.linalg.norm(los)
|
||||
poles = np.atleast_1d(multipoles)
|
||||
dig, kcount, kavg, mumesh = _initialize_pk(mesh_shape, box_shape, kedges,
|
||||
los)
|
||||
n_bins = len(kavg) + 2
|
||||
|
||||
# FFTs
|
||||
meshk = jnp.fft.fftn(mesh, norm='ortho')
|
||||
if mesh2 is None:
|
||||
mmk = meshk.real**2 + meshk.imag**2
|
||||
else:
|
||||
mmk = meshk * jnp.fft.fftn(mesh2, norm='ortho').conj()
|
||||
|
||||
# Sum powers
|
||||
pk = jnp.empty((len(poles), n_bins))
|
||||
for i_ell, ell in enumerate(poles):
|
||||
weights = (mmk * (2 * ell + 1) * legendre(ell)(mumesh)).reshape(-1)
|
||||
if mesh2 is None:
|
||||
psum = jnp.bincount(dig, weights=weights, length=n_bins)
|
||||
else: # XXX: bincount is really slow with complex numbers
|
||||
psum_real = jnp.bincount(dig, weights=weights.real, length=n_bins)
|
||||
psum_imag = jnp.bincount(dig, weights=weights.imag, length=n_bins)
|
||||
psum = (psum_real**2 + psum_imag**2)**.5
|
||||
pk = pk.at[i_ell].set(psum)
|
||||
|
||||
# Normalization and conversion from cell units to [Mpc/h]^3
|
||||
pk = (pk / kcount)[:, 1:-1] * (box_shape / mesh_shape).prod()
|
||||
|
||||
# pk = jnp.concatenate([kavg[None], pk])
|
||||
if np.ndim(multipoles) == 0:
|
||||
return kavg, pk[0]
|
||||
else:
|
||||
return kavg, pk
|
||||
|
||||
|
||||
def transfer(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||
ks, pk0 = pk_fn(mesh0)
|
||||
ks, pk1 = pk_fn(mesh1)
|
||||
return ks, (pk1 / pk0)**.5
|
||||
|
||||
|
||||
def coherence(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||
ks, pk01 = pk_fn(mesh0, mesh1)
|
||||
ks, pk0 = pk_fn(mesh0)
|
||||
ks, pk1 = pk_fn(mesh1)
|
||||
return ks, pk01 / (pk0 * pk1)**.5
|
||||
|
||||
|
||||
def pktranscoh(mesh0, mesh1, box_shape, kedges: int | float | list = None):
|
||||
pk_fn = partial(power_spectrum, box_shape=box_shape, kedges=kedges)
|
||||
ks, pk01 = pk_fn(mesh0, mesh1)
|
||||
ks, pk0 = pk_fn(mesh0)
|
||||
ks, pk1 = pk_fn(mesh1)
|
||||
return ks, pk0, pk1, (pk1 / pk0)**.5, pk01 / (pk0 * pk1)**.5
|
||||
|
||||
|
||||
def cross_correlation_coefficients(field_a,
|
||||
field_b,
|
||||
kmin=5,
|
||||
dk=0.5,
|
||||
boxsize=False):
|
||||
"""
|
||||
Calculate the cross correlation coefficients given two real space field
|
||||
|
||||
Args:
|
||||
|
||||
field: real valued field
|
||||
field_a: real valued field
|
||||
field_b: real valued field
|
||||
kmin: minimum k-value for binned powerspectra
|
||||
dk: differential in each kbin
|
||||
boxsize: length of each boxlength (can be strangly shaped?)
|
||||
|
@ -49,20 +170,21 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
|||
Returns:
|
||||
|
||||
kbins: the central value of the bins for plotting
|
||||
power: real valued array of power in each bin
|
||||
P / norm: normalized cross correlation coefficient between two field a and b
|
||||
|
||||
"""
|
||||
shape = field.shape
|
||||
shape = field_a.shape
|
||||
nx, ny, nz = shape
|
||||
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
|
||||
#fast fourier transform
|
||||
fft_image = jnp.fft.fftn(field)
|
||||
fft_image_a = jnp.fft.fftn(field_a)
|
||||
fft_image_b = jnp.fft.fftn(field_b)
|
||||
|
||||
#absolute value of fast fourier transform
|
||||
pk = jnp.real(fft_image * jnp.conj(fft_image))
|
||||
pk = fft_image_a * jnp.conj(fft_image_b)
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
|
@ -83,55 +205,6 @@ def power_spectrum(field, kmin=5, dk=0.5, boxsize=False):
|
|||
return kbins, P / norm
|
||||
|
||||
|
||||
def cross_correlation_coefficients(field_a,field_b, kmin=5, dk=0.5, boxsize=False):
|
||||
"""
|
||||
Calculate the cross correlation coefficients given two real space field
|
||||
|
||||
Args:
|
||||
|
||||
field_a: real valued field
|
||||
field_b: real valued field
|
||||
kmin: minimum k-value for binned powerspectra
|
||||
dk: differential in each kbin
|
||||
boxsize: length of each boxlength (can be strangly shaped?)
|
||||
|
||||
Returns:
|
||||
|
||||
kbins: the central value of the bins for plotting
|
||||
P / norm: normalized cross correlation coefficient between two field a and b
|
||||
|
||||
"""
|
||||
shape = field_a.shape
|
||||
nx, ny, nz = shape
|
||||
|
||||
#initialze values related to powerspectra (mode bins and weights)
|
||||
dig, Nsum, xsum, W, k, kedges = _initialize_pk(shape, boxsize, kmin, dk)
|
||||
|
||||
#fast fourier transform
|
||||
fft_image_a = jnp.fft.fftn(field_a)
|
||||
fft_image_b = jnp.fft.fftn(field_b)
|
||||
|
||||
#absolute value of fast fourier transform
|
||||
pk = fft_image_a * jnp.conj(fft_image_b)
|
||||
|
||||
#calculating powerspectra
|
||||
real = jnp.real(pk).reshape([-1])
|
||||
imag = jnp.imag(pk).reshape([-1])
|
||||
|
||||
Psum = jnp.bincount(dig, weights=(W.flatten() * imag), length=xsum.size) * 1j
|
||||
Psum += jnp.bincount(dig, weights=(W.flatten() * real), length=xsum.size)
|
||||
|
||||
P = ((Psum / Nsum)[1:-1] * boxsize.prod()).astype('float32')
|
||||
|
||||
#normalization for powerspectra
|
||||
norm = np.prod(np.array(shape[:])).astype('float32')**2
|
||||
|
||||
#find central values of each bin
|
||||
kbins = kedges[:-1] + (kedges[1:] - kedges[:-1]) / 2
|
||||
|
||||
return kbins, P / norm
|
||||
|
||||
|
||||
def gaussian_smoothing(im, sigma):
|
||||
"""
|
||||
im: 2d image
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue