Compare commits

...

13 commits

Author SHA1 Message Date
Wassim Kabalan
5807e1d3f4 format 2025-06-08 15:37:06 +02:00
Wassim Kabalan
a7fcba0e1f update wf 2025-06-08 15:35:57 +02:00
Wassim Kabalan
c6a7dd4e4e update tests 2025-06-08 15:27:32 +02:00
Wassim Kabalan
e0ba85fb58 update tests.yml 2025-06-08 12:03:44 +02:00
Wassim Kabalan
12eddc4e6a add mpi4py 2025-06-08 11:48:11 +02:00
Wassim Kabalan
0eb4c371e3 update tests 2025-06-08 11:43:15 +02:00
Wassim Kabalan
3be619a2db reorganize install in test workflow 2025-06-08 11:36:06 +02:00
Wassim Kabalan
f0b849cf5f update tolerance :) 2025-06-08 11:35:57 +02:00
Wassim Kabalan
995cc4c78c update numpy install in wf 2025-06-08 11:26:32 +02:00
Wassim Kabalan
7b7205e3b3 update notebooks/03-MultiGPU_PM_Halo.ipynb 2025-06-08 11:25:51 +02:00
Wassim Kabalan
6aacd81bd6 update test jax version 2025-06-08 10:56:34 +02:00
Wassim Kabalan
49c93aacf6 Format 2025-06-08 10:45:20 +02:00
Wassim Kabalan
41ae41ace3 Allow env variable control of caching in growth 2025-06-08 10:45:04 +02:00
8 changed files with 136 additions and 106 deletions

View file

@ -29,12 +29,11 @@ jobs:
run: |
sudo apt-get install -y libopenmpi-dev
python -m pip install --upgrade pip
pip install jax==0.4.35
pip install numpy setuptools cython wheel
pip install git+https://github.com/MP-Gadget/pfft-python
pip install git+https://github.com/MP-Gadget/pmesh
pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
pip install -r requirements-test.txt
pip install jax
pip install setuptools cython wheel mpi4py
pip install -r requirements-test.txt --no-build-isolation
pip install pytest
pip install diffrax
pip install .
- name: Run Single Device Tests
@ -43,4 +42,4 @@ jobs:
pytest -v -m "not distributed"
- name: Run Distributed tests
run: |
pytest -v -m distributed
pytest -v tests/test_distributed_pm.py

View file

@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
axis=-1)
def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
def normal_field(mesh_shape, seed, sharding=None, dtype='float32'):
"""Generate a Gaussian random field with the given power spectrum."""
gpu_mesh = sharding.mesh if sharding is not None else None
if gpu_mesh is not None and not (gpu_mesh.empty):

View file

@ -1,3 +1,5 @@
import os
import jax.numpy as np
from jax.numpy import interp
from jax_cosmo.background import *
@ -243,7 +245,11 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
Growth factor computed at requested scale factor
"""
# Check if growth has already been computed
#if not "background.growth_factor" in cosmo._workspace.keys():
CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps)
@ -281,7 +287,6 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab
cache = {
"a": atab,
"g": gtab,
@ -291,8 +296,10 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
"f2": f2tab,
"h2": h2tab,
}
if CACHING_ACTIVATED:
cosmo._workspace["background.growth_factor"] = cache
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) , cache
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache
def _growth_rate_ODE(cosmo, a):
@ -317,6 +324,7 @@ def _growth_rate_ODE(cosmo, a):
cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
return interp(a, cache["a"], cache["f"])
def _growth_factor_second_ODE(cosmo, a):
"""Compute second order growth factor D2(a) at a given scale factor,
normalised such that D(a=1) = 1.
@ -384,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
"""
# Check if growth has already been computed, if not, compute it
if not "background.growth_factor" in cosmo._workspace.keys():
CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
):
cache = cosmo._workspace["background.growth_factor"]
else:
# Compute tabulated array
atab = np.logspace(log10_amin, 0.0, steps)
@ -395,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
gtab = gtab / gtab[-1] # Normalize to a=1.
cache = {"a": atab, "g": gtab}
if CACHING_ACTIVATED:
cosmo._workspace["background.growth_factor"] = cache
else:
cache = cosmo._workspace["background.growth_factor"]
return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
@ -522,6 +533,7 @@ def gp(cosmo, a):
D1f = f1 * g1 / a
return D1f
def dGfa(cosmo, a):
r""" Derivative of Gf against a

View file

@ -240,7 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
halo_size = jax.tree.map(lambda x: x//2, halo_size)
# Halo size is halved for the read operation
# We only need to read the density field
# while in the painting operation we need to exchange and reduce the halo
# We chose to do that since it is much easier to write a custom jvp rule for exchange
# while it is a bit harder if there is a reduction involved
halo_size = jax.tree.map(lambda x: x // 2, halo_size)
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
grid_mesh = halo_exchange(grid_mesh,
halo_extents=halo_extents,

View file

@ -30,15 +30,12 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
"""Multilinear enmeshing."""
base_indices = jnp.asarray(base_indices)
displacements = jnp.asarray(displacements)
with jax.experimental.enable_x64():
cell_size = jnp.float64(
cell_size) if new_cell_size is not None else jnp.array(
cell_size, dtype=displacements.dtype)
cell_size = jnp.array(cell_size, dtype=displacements.dtype)
if base_shape is not None:
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
offset = jnp.float64(offset)
offset = offset.astype(base_indices.dtype)
if new_cell_size is not None:
new_cell_size = jnp.float64(new_cell_size)
new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
if new_shape is not None:
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)

File diff suppressed because one or more lines are too long

View file

@ -1,5 +1,3 @@
pytest>=8.0.0
diffrax
pfft-python @ git+https://github.com/MP-Gadget/pfft-python
pmesh @ git+https://github.com/MP-Gadget/pmesh
fastpm @ git+https://github.com/ASKabalan/fastpm-python

View file

@ -22,7 +22,7 @@ from jaxpm.distributed import fft3d, ifft3d
from jaxpm.painting import cic_paint, cic_paint_dx # noqa : E402
from jaxpm.pm import lpt, make_diffrax_ode, pm_forces # noqa : E402
_TOLERANCE = 1e-1 # 🙃🙃
_TOLERANCE = 1e-6 # 🎉🎉🎉
pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]