Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
Synced 2025-06-14 09:51:11 +00:00

Compare commits: 13 commits, e7112e0c25 ... 5807e1d3f4

Commits in this range (SHA1):
5807e1d3f4
a7fcba0e1f
c6a7dd4e4e
e0ba85fb58
12eddc4e6a
0eb4c371e3
3be619a2db
f0b849cf5f
995cc4c78c
7b7205e3b3
6aacd81bd6
49c93aacf6
41ae41ace3

8 changed files with 136 additions and 106 deletions

.github/workflows/tests.yml (vendored): 13 changed lines
@@ -29,12 +29,11 @@ jobs:
       run: |
         sudo apt-get install -y libopenmpi-dev
         python -m pip install --upgrade pip
-        pip install jax==0.4.35
-        pip install numpy setuptools cython wheel
-        pip install git+https://github.com/MP-Gadget/pfft-python
-        pip install git+https://github.com/MP-Gadget/pmesh
-        pip install git+https://github.com/ASKabalan/fastpm-python --no-build-isolation
-        pip install -r requirements-test.txt
+        pip install jax
+        pip install setuptools cython wheel mpi4py
+        pip install -r requirements-test.txt --no-build-isolation
+        pip install pytest
+        pip install diffrax
         pip install .

     - name: Run Single Device Tests
@@ -43,4 +42,4 @@ jobs:
         pytest -v -m "not distributed"
     - name: Run Distributed tests
       run: |
-        pytest -v -m distributed
+        pytest -v tests/test_distributed_pm.py
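
The single-device job still runs `pytest -v -m "not distributed"`, while the distributed job now targets tests/test_distributed_pm.py directly instead of selecting by the `distributed` marker. A minimal sketch of how such a marker interacts with that filter; only the marker name comes from the workflow above, the test bodies below are illustrative and not taken from the repository:

# Hypothetical sketch: how a test would be excluded from the single-device job.
# Only the "distributed" marker name is grounded in the workflow; the test
# bodies and assertions are invented for illustration.
import pytest
import jax


@pytest.mark.distributed  # deselected by `pytest -m "not distributed"`
def test_runs_only_in_the_distributed_job():
    # A distributed test would typically require more than one JAX process.
    assert jax.process_count() > 1


def test_runs_in_the_single_device_job():
    # Unmarked tests are collected by the single-device job.
    assert jax.device_count() >= 1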

@@ -166,7 +166,7 @@ def uniform_particles(mesh_shape, sharding=None):
                     axis=-1)


-def normal_field(mesh_shape, seed, sharding=None , dtype='float32'):
+def normal_field(mesh_shape, seed, sharding=None, dtype='float32'):
     """Generate a Gaussian random field with the given power spectrum."""
     gpu_mesh = sharding.mesh if sharding is not None else None
     if gpu_mesh is not None and not (gpu_mesh.empty):
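
The change above is purely cosmetic (a stray space before the comma in the signature). For context, a short usage sketch of the updated signature; the import path and the use of a `jax.random` key as `seed` are assumptions, not shown in this diff:

# Minimal sketch, assuming normal_field is importable from jaxpm.pm and
# accepts a jax.random key as `seed`; both are assumptions, not part of the diff.
import jax
from jaxpm.pm import normal_field  # hypothetical import path

mesh_shape = (64, 64, 64)
key = jax.random.PRNGKey(0)

# Single-device call: no sharding, default float32 output.
delta = normal_field(mesh_shape, seed=key, sharding=None, dtype='float32')
print(delta.shape)  # (64, 64, 64)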

jaxpm/growth.py: 110 changed lines
@@ -1,3 +1,5 @@
+import os
+
 import jax.numpy as np
 from jax.numpy import interp
 from jax_cosmo.background import *
@@ -243,56 +245,61 @@ def _growth_factor_ODE(cosmo, a, log10_amin=-3, steps=256, eps=1e-4):
       Growth factor computed at requested scale factor
     """
     # Check if growth has already been computed
-    #if not "background.growth_factor" in cosmo._workspace.keys():
-    # Compute tabulated array
-    atab = np.logspace(log10_amin, 0.0, steps)
+    CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
+    if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
+    ):
+        cache = cosmo._workspace["background.growth_factor"]
+    else:
+        # Compute tabulated array
+        atab = np.logspace(log10_amin, 0.0, steps)

     def D_derivs(y, x):
         q = (2.0 - 0.5 *
              (Omega_m_a(cosmo, x) +
               (1.0 + 3.0 * w(cosmo, x)) * Omega_de_a(cosmo, x))) / x
         r = 1.5 * Omega_m_a(cosmo, x) / x / x

         g1, g2 = y[0]
         f1, f2 = y[1]
         dy1da = [f1, -q * f1 + r * g1]
         dy2da = [f2, -q * f2 + r * g2 - r * g1**2]
         return np.array([[dy1da[0], dy2da[0]], [dy1da[1], dy2da[1]]])

     y0 = np.array([[atab[0], -3.0 / 7 * atab[0]**2],
                    [1.0, -6.0 / 7 * atab[0]]])
     y = odeint(D_derivs, y0, atab)

     # compute second order derivatives growth
     dyda2 = D_derivs(np.transpose(y, (1, 2, 0)), atab)
     dyda2 = np.transpose(dyda2, (2, 0, 1))

     # Normalize results
     y1 = y[:, 0, 0]
     gtab = y1 / y1[-1]
     y2 = y[:, 0, 1]
     g2tab = y2 / y2[-1]
     # To transform from dD/da to dlnD/dlna: dlnD/dlna = a / D dD/da
     ftab = y[:, 1, 0] / y1[-1] * atab / gtab
     f2tab = y[:, 1, 1] / y2[-1] * atab / g2tab
     # Similarly for second order derivatives
     # Note: these factors are not accessible as parent functions yet
     # since it is unclear what to refer to them with.
     htab = dyda2[:, 1, 0] / y1[-1] * atab / gtab
     h2tab = dyda2[:, 1, 1] / y2[-1] * atab / g2tab

-    cache = {
-        "a": atab,
-        "g": gtab,
-        "f": ftab,
-        "h": htab,
-        "g2": g2tab,
-        "f2": f2tab,
-        "h2": h2tab,
-    }
-
-    return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0) , cache
+        cache = {
+            "a": atab,
+            "g": gtab,
+            "f": ftab,
+            "h": htab,
+            "g2": g2tab,
+            "f2": f2tab,
+            "h2": h2tab,
+        }
+        if CACHING_ACTIVATED:
+            cosmo._workspace["background.growth_factor"] = cache
+
+    return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0), cache


 def _growth_rate_ODE(cosmo, a):
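
The ODE system integrated by D_derivs is unchanged by this commit; for reference, with g1 = D1, g2 = D2 and the q(a), r(a) defined in the code, it is the usual first- and second-order growth system written in the scale factor a (a restatement of what the code above already encodes):

% q(a) = [2 - (Omega_m(a) + (1 + 3 w(a)) Omega_de(a)) / 2] / a
% r(a) = 3 Omega_m(a) / (2 a^2)
\begin{aligned}
  \frac{d^{2} D_1}{da^{2}} &= -q(a)\,\frac{dD_1}{da} + r(a)\,D_1,\\
  \frac{d^{2} D_2}{da^{2}} &= -q(a)\,\frac{dD_2}{da} + r(a)\,D_2 - r(a)\,D_1^{2},
\end{aligned}
\qquad
f_i \equiv \frac{d\ln D_i}{d\ln a} = \frac{a}{D_i}\,\frac{dD_i}{da}.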

@@ -313,10 +320,11 @@ def _growth_rate_ODE(cosmo, a):
       Growth rate computed at requested scale factor
     """
     # Check if growth has already been computed, if not, compute it

     cache = _growth_factor_ODE(cosmo, np.atleast_1d(1.0))[1]
     return interp(a, cache["a"], cache["f"])


 def _growth_factor_second_ODE(cosmo, a):
     """Compute second order growth factor D2(a) at a given scale factor,
     normalised such that D(a=1) = 1.
@@ -384,7 +392,11 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):

     """
     # Check if growth has already been computed, if not, compute it
-    if not "background.growth_factor" in cosmo._workspace.keys():
+    CACHING_ACTIVATED = os.environ.get("JC_CACHE", "1") == "1"
+    if CACHING_ACTIVATED and "background.growth_factor" in cosmo._workspace.keys(
+    ):
+        cache = cosmo._workspace["background.growth_factor"]
+    else:
         # Compute tabulated array
         atab = np.logspace(log10_amin, 0.0, steps)
@@ -395,9 +407,8 @@ def _growth_factor_gamma(cosmo, a, log10_amin=-3, steps=128):
         gtab = np.exp(odeint(integrand, np.log(atab[0]), np.log(atab)))
         gtab = gtab / gtab[-1]  # Normalize to a=1.
         cache = {"a": atab, "g": gtab}
-        cosmo._workspace["background.growth_factor"] = cache
-    else:
-        cache = cosmo._workspace["background.growth_factor"]
+        if CACHING_ACTIVATED:
+            cosmo._workspace["background.growth_factor"] = cache

     return np.clip(interp(a, cache["a"], cache["g"]), 0.0, 1.0)
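
Both growth-factor routines now gate the `cosmo._workspace` cache behind the same environment switch. A standalone sketch of the pattern; only the JC_CACHE convention (unset or "1" enables caching, anything else disables it) comes from the diff, while the helper function and workspace dict below are stand-ins:

# Standalone sketch of the caching toggle used above; the function and the
# workspace dict are invented, only the JC_CACHE convention is from the diff.
import os


def get_growth_cache(workspace, compute_tables):
    caching_activated = os.environ.get("JC_CACHE", "1") == "1"
    if caching_activated and "background.growth_factor" in workspace:
        cache = workspace["background.growth_factor"]
    else:
        cache = compute_tables()
        if caching_activated:
            workspace["background.growth_factor"] = cache
    return cache


# Running with caching disabled, e.g. for reproducibility checks:
#     JC_CACHE=0 python my_script.py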

@@ -522,6 +533,7 @@ def gp(cosmo, a):
     D1f = f1 * g1 / a
     return D1f


 def dGfa(cosmo, a):
     r""" Derivative of Gf against a

@@ -592,4 +604,4 @@ def dGf2a(cosmo, a):
     f2p = interp(np.log(a), np.log(cache['a']), f2p)
     E_a = E(cosmo, a)
     return (f2p * a**3 * E_a + D2f * a**3 * dEa(cosmo, a) +
             3 * a**2 * E_a * D2f)

@@ -240,7 +240,12 @@ def _cic_read_dx_impl(grid_mesh, disp, halo_size):
 def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):

     halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
-    halo_size = jax.tree.map(lambda x: x//2, halo_size)
+    # Halo size is halved for the read operation
+    # We only need to read the density field
+    # while in the painting operation we need to exchange and reduce the halo
+    # We chose to do that since it is much easier to write a custom jvp rule for exchange
+    # while it is a bit harder if there is a reduction involved
+    halo_size = jax.tree.map(lambda x: x // 2, halo_size)
     grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
     grid_mesh = halo_exchange(grid_mesh,
                               halo_extents=halo_extents,
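
The functional line is unchanged (the halo is still halved before padding); the commit only adds comments explaining why a read needs just an exchange while painting also needs a reduction. For concreteness, a tiny sketch of what the halving does; the nested halo_size value below is invented, the real layout is defined elsewhere in JaxPM:

# Sketch only: jax.tree.map halves every leaf of a (hypothetical) nested
# halo_size structure; the actual structure is not shown in this diff.
import jax

halo_size = ((16, 16), (16, 16), (0, 0))  # hypothetical per-axis halo widths
halved = jax.tree.map(lambda x: x // 2, halo_size)
print(halved)  # ((8, 8), (8, 8), (0, 0))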

@@ -30,17 +30,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset,
     """Multilinear enmeshing."""
     base_indices = jnp.asarray(base_indices)
     displacements = jnp.asarray(displacements)
-    with jax.experimental.enable_x64():
-        cell_size = jnp.float64(
-            cell_size) if new_cell_size is not None else jnp.array(
-                cell_size, dtype=displacements.dtype)
-        if base_shape is not None:
-            base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
-        offset = jnp.float64(offset)
-        if new_cell_size is not None:
-            new_cell_size = jnp.float64(new_cell_size)
-        if new_shape is not None:
-            new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
+    cell_size = jnp.array(cell_size, dtype=displacements.dtype)
+    if base_shape is not None:
+        base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
+    offset = offset.astype(base_indices.dtype)
+    if new_cell_size is not None:
+        new_cell_size = jnp.array(new_cell_size, dtype=displacements.dtype)
+    if new_shape is not None:
+        new_shape = jnp.array(new_shape, dtype=base_indices.dtype)

     spatial_dim = base_indices.shape[1]
     neighbor_offsets = (
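
The rewrite drops the `jax.experimental.enable_x64()` context and the `jnp.float64` promotions in favour of casting each scalar to the dtype of the array it will be combined with. A standalone sketch of the new style with invented values; it is not the enmesh code itself:

# Standalone sketch, not the enmesh code: cast scalars to the dtype of the
# arrays they will be combined with, so no global x64 mode is required.
import jax.numpy as jnp

displacements = jnp.ones((4, 3), dtype=jnp.float32)
base_indices = jnp.zeros((4, 3), dtype=jnp.int16)

cell_size = jnp.array(0.5, dtype=displacements.dtype)       # float32, matches displacements
offset = jnp.array([1, 2, 3]).astype(base_indices.dtype)    # int16, matches base_indices
print(cell_size.dtype, offset.dtype)  # float32 int16
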
File diff suppressed because one or more lines are too long

@@ -1,5 +1,3 @@
-pytest>=8.0.0
-diffrax
 pfft-python @ git+https://github.com/MP-Gadget/pfft-python
 pmesh @ git+https://github.com/MP-Gadget/pmesh
 fastpm @ git+https://github.com/ASKabalan/fastpm-python

@@ -22,7 +22,7 @@ from jaxpm.distributed import fft3d, ifft3d
 from jaxpm.painting import cic_paint, cic_paint_dx  # noqa : E402
 from jaxpm.pm import lpt, make_diffrax_ode, pm_forces  # noqa : E402

-_TOLERANCE = 1e-1  # 🙃🙃
+_TOLERANCE = 1e-6  # 🎉🎉🎉

 pdims = [(1, 8), (8, 1), (4, 2), (2, 4)]
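
The distributed test tolerance tightens from 1e-1 to 1e-6, the visible payoff of the halo and dtype fixes above. How the constant is consumed is not shown in this hunk; a hedged sketch of the usual pattern, where the arrays and the assert_allclose call are illustrative only:

# Hedged sketch: typical use of a tolerance constant like _TOLERANCE.
# The arrays and the comparison below are invented; the real assertions
# live in the test file and are not part of this hunk.
import jax.numpy as jnp
from numpy.testing import assert_allclose

_TOLERANCE = 1e-6

single_device_field = jnp.ones((8, 8, 8))
distributed_field = jnp.ones((8, 8, 8)) + 1e-8  # stand-in for a multi-GPU result

assert_allclose(distributed_field, single_device_field,
                rtol=_TOLERANCE, atol=_TOLERANCE)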