Fix sharding error (#37)
Some checks failed
Code Formatting / formatting (push) Failing after 4m30s
Tests / run_tests (3.10) (push) Failing after 1m41s
Tests / run_tests (3.11) (push) Failing after 1m42s
Tests / run_tests (3.12) (push) Failing after 1m15s

* Use cosmo as arg for the ODE function

* Update examples

* format

* notebook update

* fix tests

* add correct annotations for weights in painting and warning for cic_paint in distributed pm

* update test_against_fpm

* update distributed tests and add jacfwd jacrev and vmap tests

* format

* add Caveats to notebook readme

* final touches

* update Growth.py to allow using FastPM solver

* fix 2D painting when input is (X , Y , 2) shape

* update cic read halo size and notebooks examples

* Allow env variable control of caching in growth

* Format

* update test jax version

* update notebooks/03-MultiGPU_PM_Halo.ipynb

* update numpy install in wf

* update tolerance :)

* reorganize install in test workflow

* update tests

* add mpi4py

* update tests.yml

* update tests

* update wf

* format

* make normal_field signature consistent with jax.random.normal

* update by default normal_field dtype to match JAX

* format

* debug test workflow

* format

* debug test workflow

* updating tests

* fix accuracy

* fixed tolerance

* adding caching

* Update conftest.py

* Update tolerance and precision settings in distributed PM tests

* revererting back changes to growth.py

---------

Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
Wassim KABALAN 2025-06-28 23:07:31 +02:00 committed by GitHub
parent cb2a7ab17f
commit 6693e5c725
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 675 additions and 298 deletions

View file

@ -44,7 +44,7 @@ def simulation_config(request):
return request.param
@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
@pytest.fixture(scope="session", params=[0.1, 0.2])
def lpt_scale_factor(request):
return request.param
@ -151,7 +151,7 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
finalstate = solver.nbody(fpm_lpt1, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value
@ -167,7 +167,7 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
if lpt_scale_factor == 0.8:
pytest.skip("Do not run nbody simulation from scale factor 0.8")
stages = np.linspace(lpt_scale_factor, 1.0, 10, endpoint=True)
stages = np.linspace(lpt_scale_factor, 1.0, 100, endpoint=True)
finalstate = solver.nbody(fpm_lpt2, leapfrog(stages))
fpm_mesh = particle_mesh.paint(finalstate.X).value