JaxPM

JAX-powered Cosmological Particle-Mesh N-body Solver

Note

JaxPM v0.1.xx adds support for distributing the model across multiple GPUs while remaining compatible with previous releases. These changes are still under development and testing, so please report any issues you encounter. For the older, more stable release, install:

pip install jaxpm==0.0.2

Install

Basic installation can be done using pip:

pip install jaxpm

For an optimized installation for distributed runs on GPU clusters, install jaxDecomp first; see the jaxDecomp installation instructions.
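Before configuring a distributed run, it can help to confirm which accelerators JAX sees on each node. The minimal sketch below uses only core JAX calls (`jax.default_backend`, `jax.local_devices`, `jax.local_device_count`, `jax.device_count`); it does not exercise jaxDecomp itself. On a correctly configured GPU node the backend should not report `cpu`.

```python
import jax

# Backend JAX selected by default (for example "cpu" or "gpu").
backend = jax.default_backend()

# Devices attached to this process; jax.device_count() also counts devices
# owned by other processes in a multi-node run.
local_devices = jax.local_devices()

print(f"backend={backend}, local devices={jax.local_device_count()}, "
      f"global devices={jax.device_count()}")
for d in local_devices:
    # Each device reports its id, platform and hardware kind.
    print(d.id, d.platform, d.device_kind)
```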

Goals

Provide a modern infrastructure to support differentiable PM N-body simulations using JAX:

  • Keep the implementation simple and readable, written in the pure NumPy API (jax.numpy)
  • Forward- and reverse-mode automatic differentiation to any order
  • Automatic batching using vmap
  • Compatibility with external optimizer libraries such as optax
  • Fully distributable on multi-GPU and multi-node systems using jaxDecomp, working with JAX v0.4.35
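The goals above can be illustrated with a small, self-contained sketch. It is not JaxPM code: `cic_paint` is a hypothetical helper written for this example, implementing standard cloud-in-cell (CIC) mass assignment onto a periodic mesh using only `jax.numpy`. It shows that such a kernel, written against the NumPy-style API, can be differentiated in reverse mode (`jax.grad`), forward mode (`jax.jvp`), to second order (a forward-over-reverse Hessian-vector product), and batched with `jax.vmap` without modification.

```python
import jax
import jax.numpy as jnp

# The 8 corner offsets of a unit cube, one row per neighbouring cell.
OFFSETS = jnp.array([[i, j, k] for i in (0, 1) for j in (0, 1) for k in (0, 1)])


def cic_paint(mesh, positions, weight=1.0):
    """Hypothetical CIC helper (not the JaxPM API): deposit particles onto a
    periodic 3D mesh. `positions` has shape (n, 3) and is in grid units."""
    base = jnp.floor(positions)                # lower-corner cell (as float)
    frac = positions - base                    # offset inside the cell, in [0, 1)
    # Per-axis weight is frac for the upper neighbour and (1 - frac) for the
    # lower one; the cell weight is the product over the 3 axes, shape (n, 8).
    upper = OFFSETS[None, :, :] == 1
    w = jnp.prod(jnp.where(upper, frac[:, None, :], 1.0 - frac[:, None, :]), axis=-1)
    # Integer cell indices, wrapped periodically, shape (n, 8, 3).
    idx = (base[:, None, :].astype(jnp.int32) + OFFSETS[None, :, :]) % jnp.array(mesh.shape)
    # Scatter-add accumulates contributions from particles sharing a cell.
    return mesh.at[idx[..., 0], idx[..., 1], idx[..., 2]].add(weight * w)


n, ngrid = 1000, 16
pos = jax.random.uniform(jax.random.PRNGKey(0), (n, 3), maxval=float(ngrid))

# Mass conservation: each particle's 8 weights sum to 1, so the total is close to n.
density = cic_paint(jnp.zeros((ngrid,) * 3), pos)


def loss(p):
    # Scalar summary of the painted field (sum of squared densities).
    return jnp.sum(cic_paint(jnp.zeros((ngrid,) * 3), p) ** 2)


grad_pos = jax.grad(loss)(pos)                                  # reverse mode, (n, 3)
_, dloss = jax.jvp(loss, (pos,), (jnp.ones_like(pos),))         # forward mode, scalar
hvp = jax.jvp(jax.grad(loss), (pos,), (jnp.ones_like(pos),))[1]  # second order, (n, 3)

# Batching over several particle sets with vmap: shape (4, 16, 16, 16).
batch = jax.random.uniform(jax.random.PRNGKey(1), (4, n, 3), maxval=float(ngrid))
densities = jax.vmap(lambda p: cic_paint(jnp.zeros((ngrid,) * 3), p))(batch)
```

The `.at[...].add` scatter is what lets several particles deposit into the same cell; an indexed assignment would keep only one of the colliding contributions.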

Contributors

Thanks goes to these wonderful people (emoji key):

  • Francois Lanusse 🤔
  • Denise Lanzieri 💻
  • Wassim KABALAN 💻 🚇 👀
  • Hugo Simon-Onfroy 💻
  • Alexandre Boucaud 👀

This project follows the all-contributors specification. Contributions of any kind welcome!