* Use cosmo as arg for the ODE function
* Update examples
* format
* notebook update
* fix tests
* add correct annotations for weights in painting and warning for cic_paint in distributed pm
* update test_against_fpm
* update distributed tests and add jacfwd jacrev and vmap tests
* format
* add Caveats to notebook readme
* final touches
* update Growth.py to allow using FastPM solver
* fix 2D painting when input is (X , Y , 2) shape
* update cic read halo size and notebooks examples
* Allow env variable control of caching in growth
* Format
* update test jax version
* update notebooks/03-MultiGPU_PM_Halo.ipynb
* update numpy install in wf
* update tolerance :)
* reorganize install in test workflow
* update tests
* add mpi4py
* update tests.yml
* update tests
* update wf
* format
* make normal_field signature consistent with jax.random.normal
* update by default normal_field dtype to match JAX
* format
* debug test workflow
* format
* debug test workflow
* updating tests
* fix accuracy
* fixed tolerance
* adding caching
* Update conftest.py
* Update tolerance and precision settings in distributed PM tests
* revererting back changes to growth.py
---------
Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
* adding example of distributed solution
* put back old functgion
* update formatting
* add halo exchange and slice pad
* apply formatting
* implement distributed optimized cic_paint
* Use new cic_paint with halo
* Fix seed for distributed normal
* Wrap interpolation function to avoid all gather
* Return normal order frequencies for single GPU
* add example
* format
* add optimised bench script
* times in ms
* add lpt2
* update benchmark and add slurm
* Visualize only final field
* Update scripts/distributed_pm.py
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
* Adjust pencil type for frequencies
* fix painting issue with slabs
* Shared operation in fourrier space now take inverted sharding axis for
slabs
* add assert to make pyright happy
* adjust test for hpc-plotter
* add PMWD test
* bench
* format
* added github workflow
* fix formatting from main
* Update for jaxDecomp pure JAX
* revert single halo extent change
* update for latest jaxDecomp
* remove fourrier_space in autoshmap
* make normal_field work with single controller
* format
* make distributed pm work in single controller
* merge bench_pm
* update to leapfrog
* add a strict dependency on jaxdecomp
* global mesh no longer needed
* kernels.py no longer uses global mesh
* quick fix in distributed
* pm.py no longer uses global mesh
* painting.py no longer uses global mesh
* update demo script
* quick fix in kernels
* quick fix in distributed
* update demo
* merge hugos LPT2 code
* format
* Small fix
* format
* remove duplicate get_ode_fn
* update visualizer
* update compensate CIC
* By default check_rep is false for shard_map
* remove experimental distributed code
* update PGDCorrection and neural ode to use new fft3d
* jaxDecomp pfft3d promotes to complex automatically
* remove deprecated stuff
* fix painting issue with read_cic
* use jnp interp instead of jc interp
* delete old slurms
* add notebook examples
* apply formatting
* add distributed zeros
* fix code in LPT2
* jit cic_paint
* update notebooks
* apply formating
* get local shape and zeros can be used by users
* add a user facing function to create uniform particle grid
* use jax interp instead of jax_cosmo
* use float64 for enmeshing
* Allow applying weights with relative cic paint
* Weights can be traced
* remove script folder
* update example notebooks
* delete outdated design file
* add readme for tutorials
* update readme
* fix small error
* forgot particles in multi host
* clarifying why cic_paint_dx is slower
* clarifying the halo size dependence on the box size
* ability to choose snapshots number with MultiHost script
* Adding animation notebook
* Put plotting in package
* Add finite difference laplace kernel + powerspec functions from Hugo
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
* Put plotting utils in package
* By default use absoulute painting with
* update code
* update notebooks
* add tests
* Upgrade setup.py to pyproject
* Format
* format tests
* update test dependencies
* add test workflow
* fix deprecated FftType in jaxpm.kernels
* Add aboucaud comments
* JAX version is 0.4.35 until Diffrax new release
* add numpy explicitly as dependency for tests
* fix install order for tests
* add numpy to be installed
* enforce no build isolation for fastpm
* pip install jaxpm test without build isolation
* bump jaxdecomp version
* revert test workflow
* remove outdated tests
---------
Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d