JaxPM

JAX-powered Cosmological Particle-Mesh N-body Solver

Note

The new JaxPM v0.1.xx supports multi-GPU model distribution while remaining compatible with previous releases. These significant changes are still under development and testing, so please report any issues you encounter. For the older but more stable version, install:

pip install jaxpm==0.0.2

Install

Basic installation can be done using pip:

pip install jaxpm

For optimized distribution on GPU clusters, please install jaxDecomp first; see the jaxDecomp installation instructions.
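Distributing a PM simulation means splitting the simulation mesh across devices. As a minimal sketch that uses only JAX's built-in sharding API (not jaxDecomp's or JaxPM's own interfaces, which are not reproduced here), a 3D field can be laid out on a 2D device mesh as follows. The mesh layout, the axis names `x`/`y` and the 32 x 32 x 32 grid are illustrative assumptions. The example also runs on a single device, where the device mesh degenerates to 1 x 1.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange every visible device into a 2D device mesh. On a single CPU or GPU
# this is a 1 x 1 mesh. The (1, n_devices) layout is an illustrative choice;
# other layouts such as (2, 4) on 8 GPUs are equally valid.
devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(1, len(devices)), axis_names=("x", "y"))

# Split the first two axes of the field over the device mesh and keep the
# third axis local to each device (a "pencil" decomposition).
sharding = NamedSharding(mesh, P("x", "y"))

# The grid size must be divisible by the number of devices along each split axis.
mesh_size = (32, 32, 32)
key = jax.random.PRNGKey(0)
field = jax.device_put(jax.random.normal(key, mesh_size), sharding)

# jnp.fft accepts the sharded array and XLA inserts whatever communication is
# needed; jaxDecomp instead provides dedicated distributed 3D FFTs for this step.
field_k = jnp.fft.rfftn(field)
```

The same code scales from one device to many because the sharding, not the computation, carries the layout; only the device-mesh shape changes.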

Goals

Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:

  • Keep the implementation simple and readable, using the NumPy API (via jax.numpy)
  • Forward- and reverse-mode automatic differentiation of any order
  • Support automated batching using vmap
  • Compatibility with external optimizer libraries like optax
  • Now fully distributable on multi-GPU and multi-node systems using jaxDecomp, working with JAX v0.4.35
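The differentiation and batching goals can be illustrated with a toy example in plain JAX. This sketch assumes nothing about JaxPM's own API: `cic_paint_sketch` is a hypothetical, simplified cloud-in-cell painter (the mass-assignment step of a PM solver), and `loss` is an arbitrary stand-in objective. It applies reverse mode (`jax.grad`), forward mode (`jax.jvp`) and automatic batching (`jax.vmap`) to the same function.

```python
import jax
import jax.numpy as jnp


def cic_paint_sketch(positions, mesh_shape):
    """Cloud-in-cell (trilinear) deposit of unit-mass particles on a periodic grid.

    positions: (N, 3) particle positions in grid units.
    mesh_shape: static tuple, e.g. (8, 8, 8).
    Hypothetical helper for illustration only; not JaxPM's painting routine.
    """
    base = jnp.floor(positions)  # lower corner of each particle's cell
    # The 8 corners of a cell, as integer offsets from the lower corner.
    offsets = jnp.array([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)])
    corners = base[:, None, :] + offsets[None, :, :]  # (N, 8, 3)
    # Trilinear weight: product over axes of (1 - distance to the corner).
    weights = jnp.prod(1.0 - jnp.abs(positions[:, None, :] - corners), axis=-1)
    # Wrap corner indices for periodic boundaries.
    idx = jnp.mod(corners, jnp.array(mesh_shape)).astype(jnp.int32)
    mesh = jnp.zeros(mesh_shape)
    # Scatter-add: contributions landing in the same cell accumulate.
    return mesh.at[idx[..., 0], idx[..., 1], idx[..., 2]].add(weights)


mesh_shape = (8, 8, 8)
n_particles = 100
positions = jax.random.uniform(jax.random.PRNGKey(0), (n_particles, 3), maxval=8.0)
density = cic_paint_sketch(positions, mesh_shape)


def loss(pos):
    """Toy scalar objective: total squared density contrast."""
    rho = cic_paint_sketch(pos, mesh_shape)
    delta = rho / rho.mean() - 1.0
    return jnp.sum(delta**2)


# Reverse mode: gradient of the loss w.r.t. every particle position.
grads = jax.grad(loss)(positions)  # shape (100, 3)

# Forward mode: directional derivative along an arbitrary tangent.
tangent = jnp.ones_like(positions)
loss_value, jvp_out = jax.jvp(loss, (positions,), (tangent,))

# Automatic batching: paint 4 independent particle sets in one call.
batch = jax.random.uniform(jax.random.PRNGKey(1), (4, n_particles, 3), maxval=8.0)
batched_density = jax.vmap(lambda p: cic_paint_sketch(p, mesh_shape))(batch)
```

Because the cell index comes from `jnp.floor`, gradients flow only through the trilinear weights, which keeps the painting step differentiable almost everywhere. Gradients such as `grads` are the kind of input an optimizer library like optax consumes.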

Open development and use

Current expectations are:

  • This project is and will remain open source, and usable without any restrictions for any purposes
  • The project will be described in a short publication in The Journal of Open Source Software (JOSS)
  • Everyone is welcome to contribute, and contributors can join the JOSS publication as authors until it is submitted to the journal.
  • Anyone (including main contributors) can use this code as a framework to build and publish their own applications, with no expectation that they need to extend authorship to all jaxpm developers.

Getting Started

To explore JaxPM's capabilities, see the notebooks section for detailed tutorials and examples covering setups from single-device simulations to multi-host configurations. The notebooks' README provides a structured guide through each tutorial.

Contributors

Thanks goes to these wonderful people (emoji key):

  • Francois Lanusse: 🤔 (ideas and planning)
  • Denise Lanzieri: 💻 (code)
  • Wassim KABALAN: 💻 🚇 👀 (code, infrastructure, reviewed pull requests)
  • Hugo Simon-Onfroy: 💻 (code)
  • Alexandre Boucaud: 👀 (reviewed pull requests)

This project follows the all-contributors specification. Contributions of any kind welcome!