Commit graph

16 commits

Author SHA1 Message Date
5868c71522 fix: use proper kernels for PM integration 2025-02-10 07:49:34 +01:00
Wassim KABALAN
df8602b318 jaxdecomp proto (#21)
* adding example of distributed solution

* put back old functgion

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operation in fourrier space now take inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge hugos LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formating

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absoulute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
2024-12-20 05:44:02 -05:00
Hugo Simonfroy
9b21eed3b5 2lpt, get_ode, invlaplace, docstrings 2024-07-31 00:46:53 +02:00
Francois Lanusse
b949827e92 Update jaxpm/pm.py 2024-07-19 10:49:51 -04:00
Francois Lanusse
9a279d2d6c Merge branch 'main' into neural_ode 2024-07-19 10:48:09 -04:00
EiffL
f28442bb48 Applying formatting 2024-07-09 14:54:34 -04:00
denise lanzieri
84b79af7f8 creoss correlation function 2022-06-18 18:23:46 +02:00
denise lanzieri
8b885450a8 few adjustments to PGD correction 2022-06-13 17:17:19 +02:00
denise lanzieri
d8a1dbe210 neural ode added 2022-06-11 14:28:30 +02:00
Denise Lanzieri
3df0c05d29 Update jaxpm/pm.py
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2022-05-18 09:59:59 +02:00
Denise Lanzieri
dd71d359e0 Update jaxpm/pm.py
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2022-05-18 09:58:46 +02:00
Denise Lanzieri
77827fcf44 Update jaxpm/pm.py
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2022-05-18 09:58:34 +02:00
Denise Lanzieri
85b2f4f097 Update jaxpm/pm.py
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
2022-05-18 09:55:31 +02:00
denise lanzieri
6b6b414195 PGD 2022-05-17 15:28:30 +02:00
EiffL
079cebbdea fix normalization of init cond 2022-03-25 22:34:13 +01:00
EiffL
8543246f62 Adds demo and notebooks 2022-02-14 01:59:12 +01:00