mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-30 09:01:11 +00:00
Fix sharding error (#37)
* Use cosmo as arg for the ODE function * Update examples * format * notebook update * fix tests * add correct annotations for weights in painting and warning for cic_paint in distributed pm * update test_against_fpm * update distributed tests and add jacfwd jacrev and vmap tests * format * add Caveats to notebook readme * final touches * update Growth.py to allow using FastPM solver * fix 2D painting when input is (X , Y , 2) shape * update cic read halo size and notebooks examples * Allow env variable control of caching in growth * Format * update test jax version * update notebooks/03-MultiGPU_PM_Halo.ipynb * update numpy install in wf * update tolerance :) * reorganize install in test workflow * update tests * add mpi4py * update tests.yml * update tests * update wf * format * make normal_field signature consistent with jax.random.normal * update by default normal_field dtype to match JAX * format * debug test workflow * format * debug test workflow * updating tests * fix accuracy * fixed tolerance * adding caching * Update conftest.py * Update tolerance and precision settings in distributed PM tests * revererting back changes to growth.py --------- Co-authored-by: Francois Lanusse <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
This commit is contained in:
parent
cb2a7ab17f
commit
6693e5c725
17 changed files with 675 additions and 298 deletions
|
@ -2,6 +2,7 @@ import pytest
|
|||
from diffrax import Dopri5, ODETerm, PIDController, SaveAt, diffeqsolve
|
||||
from helpers import MSE, MSRE
|
||||
from jax import numpy as jnp
|
||||
from numpy.testing import assert_allclose
|
||||
|
||||
from jaxpm.distributed import uniform_particles
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx
|
||||
|
@ -10,6 +11,8 @@ from jaxpm.utils import power_spectrum
|
|||
|
||||
_TOLERANCE = 1e-4
|
||||
_PM_TOLERANCE = 1e-3
|
||||
_FIELD_RTOL = 1e-4
|
||||
_FIELD_ATOL = 1e-3
|
||||
|
||||
|
||||
@pytest.mark.single_device
|
||||
|
@ -34,7 +37,10 @@ def test_lpt_absolute(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert_allclose(lpt_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
|
@ -55,7 +61,10 @@ def test_lpt_relative(simulation_config, initial_conditions, lpt_scale_factor,
|
|||
_, jpm_ps = power_spectrum(lpt_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(lpt_field, fpm_ref_field) < _TOLERANCE
|
||||
assert_allclose(lpt_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _TOLERANCE
|
||||
|
||||
|
||||
|
@ -76,7 +85,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
a=lpt_scale_factor,
|
||||
order=order)
|
||||
|
||||
ode_fn = ODETerm(make_diffrax_ode(cosmo, mesh_shape))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-8,
|
||||
|
@ -95,6 +104,7 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
|
@ -105,7 +115,10 @@ def test_nbody_absolute(simulation_config, initial_conditions,
|
|||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert_allclose(final_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
||||
|
||||
|
@ -121,8 +134,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo, initial_conditions, a=lpt_scale_factor, order=order)
|
||||
|
||||
ode_fn = ODETerm(
|
||||
make_diffrax_ode(cosmo, mesh_shape, paint_absolute_pos=False))
|
||||
ode_fn = ODETerm(make_diffrax_ode(mesh_shape, paint_absolute_pos=False))
|
||||
|
||||
solver = Dopri5()
|
||||
controller = PIDController(rtol=1e-9,
|
||||
|
@ -141,6 +153,7 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
t1=1.0,
|
||||
dt0=None,
|
||||
y0=y0,
|
||||
args=cosmo,
|
||||
stepsize_controller=controller,
|
||||
saveat=saveat)
|
||||
|
||||
|
@ -151,5 +164,8 @@ def test_nbody_relative(simulation_config, initial_conditions,
|
|||
_, jpm_ps = power_spectrum(final_field, box_shape=box_shape)
|
||||
_, fpm_ps = power_spectrum(fpm_ref_field, box_shape=box_shape)
|
||||
|
||||
assert MSE(final_field, fpm_ref_field) < _PM_TOLERANCE
|
||||
assert_allclose(final_field,
|
||||
fpm_ref_field,
|
||||
rtol=_FIELD_RTOL,
|
||||
atol=_FIELD_ATOL)
|
||||
assert MSRE(jpm_ps, fpm_ps) < _PM_TOLERANCE
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue