jaxdecomp proto (#21)

* adding example of distributed solution

* put back old functgion

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operation in fourrier space now take inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formating

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absoulute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
This commit is contained in:
Wassim KABALAN 2024-12-20 11:44:02 +01:00 committed by GitHub
parent a0a79277e5
commit df8602b318
26 changed files with 1871 additions and 434 deletions

View file

@ -1,30 +1,46 @@
import jax.numpy as jnp
import numpy as np
from jax.lax import FftType
from jax.sharding import PartitionSpec as P
from jaxdecomp import fftfreq3d, get_output_specs
from jaxpm.distributed import autoshmap
def fftk(shape, symmetric=True, finite=False, dtype=np.float32):
"""
Return wave-vectors for a given shape
def fftk(k_array):
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kdshape = np.ones(len(shape), dtype='int')
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
kdshape[d] = len(kd)
kd = kd.reshape(kdshape)
Generate Fourier transform wave numbers for a given mesh.
k.append(kd.astype(dtype))
del kd, kdshape
return k
Args:
nc (int): Shape of the mesh grid.
Returns:
list: List of wave number arrays for each dimension in
the order [kx, ky, kz].
"""
kx, ky, kz = fftfreq3d(k_array)
# to the order of dimensions in the transposed FFT
return kx, ky, kz
def interpolate_power_spectrum(input, k, pk, sharding=None):
pk_fn = lambda x: jnp.interp(x.reshape(-1), k, pk).reshape(x.shape)
gpu_mesh = sharding.mesh if sharding is not None else None
specs = sharding.spec if sharding is not None else P()
out_specs = P(*get_output_specs(
FftType.FFT, specs, mesh=gpu_mesh)) if gpu_mesh is not None else P()
return autoshmap(pk_fn,
gpu_mesh=gpu_mesh,
in_specs=out_specs,
out_specs=out_specs)(input)
def gradient_kernel(kvec, direction, order=1):
"""
Computes the gradient kernel in the requested direction
Parameters
-----------
kvec: list
@ -50,23 +66,30 @@ def gradient_kernel(kvec, direction, order=1):
return wts
def invlaplace_kernel(kvec):
def invlaplace_kernel(kvec, fd=False):
"""
Compute the inverse Laplace kernel
Compute the inverse Laplace kernel.
cf. [Feng+2016](https://arxiv.org/pdf/1603.00476)
Parameters
-----------
kvec: list
List of wave-vectors
fd: bool
Finite difference kernel
Returns
--------
wts: array
Complex kernel values
"""
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk==0, 1, kk)
return - jnp.where(kk==0, 0, 1 / kk_nozeros)
if fd:
kk = sum((ki * jnp.sinc(ki / (2 * jnp.pi)))**2 for ki in kvec)
else:
kk = sum(ki**2 for ki in kvec)
kk_nozeros = jnp.where(kk == 0, 1, kk)
return -jnp.where(kk == 0, 0, 1 / kk_nozeros)
def longrange_kernel(kvec, r_split):
@ -79,12 +102,10 @@ def longrange_kernel(kvec, r_split):
List of wave-vectors
r_split: float
Splitting radius
Returns
--------
wts: array
Complex kernel values
TODO: @modichirag add documentation
"""
if r_split != 0:
@ -105,13 +126,12 @@ def cic_compensation(kvec):
-----------
kvec: list
List of wave-vectors
Returns:
--------
wts: array
Complex kernel values
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
kwts = [jnp.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts