mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
Format
This commit is contained in:
parent
7823fdaf98
commit
af29c4005d
7 changed files with 68 additions and 63 deletions
|
@ -222,8 +222,7 @@ def cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
disp = disp.reshape([-1, 3])
|
disp = disp.reshape([-1, 3])
|
||||||
|
|
||||||
return gather(pmid, disp,
|
return gather(pmid, disp, grid_mesh).reshape(original_shape)
|
||||||
grid_mesh).reshape(original_shape)
|
|
||||||
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(2, 3))
|
@partial(jax.jit, static_argnums=(2, 3))
|
||||||
|
|
|
@ -25,12 +25,15 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
return remainder, chunks
|
return remainder, chunks
|
||||||
|
|
||||||
|
|
||||||
def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
|
def enmesh(base_indices, displacements, cell_size, base_shape, offset,
|
||||||
|
new_cell_size, new_shape):
|
||||||
"""Multilinear enmeshing."""
|
"""Multilinear enmeshing."""
|
||||||
base_indices = jnp.asarray(base_indices)
|
base_indices = jnp.asarray(base_indices)
|
||||||
displacements = jnp.asarray(displacements)
|
displacements = jnp.asarray(displacements)
|
||||||
with jax.experimental.enable_x64():
|
with jax.experimental.enable_x64():
|
||||||
cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
|
cell_size = jnp.float64(
|
||||||
|
cell_size) if new_cell_size is not None else jnp.array(
|
||||||
|
cell_size, dtype=displacements.dtype)
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||||
offset = jnp.float64(offset)
|
offset = jnp.float64(offset)
|
||||||
|
@ -40,12 +43,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_
|
||||||
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||||
|
|
||||||
spatial_dim = base_indices.shape[1]
|
spatial_dim = base_indices.shape[1]
|
||||||
neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
neighbor_offsets = (
|
||||||
|
jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
||||||
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
||||||
|
|
||||||
if new_cell_size is not None:
|
if new_cell_size is not None:
|
||||||
particle_positions = base_indices * cell_size + displacements - offset
|
particle_positions = base_indices * cell_size + displacements - offset
|
||||||
particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis
|
particle_positions = particle_positions[:, jnp.
|
||||||
|
newaxis] # insert neighbor axis
|
||||||
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
||||||
|
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
|
@ -56,7 +61,9 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_
|
||||||
new_displacements = particle_positions - new_indices * new_cell_size
|
new_displacements = particle_positions - new_indices * new_cell_size
|
||||||
|
|
||||||
if base_shape is not None:
|
if base_shape is not None:
|
||||||
new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
new_displacements -= jnp.rint(
|
||||||
|
new_displacements / grid_length
|
||||||
|
) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||||
|
|
||||||
new_indices = new_indices.astype(base_indices.dtype)
|
new_indices = new_indices.astype(base_indices.dtype)
|
||||||
new_displacements = new_displacements.astype(displacements.dtype)
|
new_displacements = new_displacements.astype(displacements.dtype)
|
||||||
|
|
14
jaxpm/pm.py
14
jaxpm/pm.py
|
@ -1,9 +1,7 @@
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
|
||||||
from jaxpm.distributed import (fft3d, ifft3d,
|
from jaxpm.distributed import fft3d, ifft3d, normal_field
|
||||||
normal_field)
|
|
||||||
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||||
growth_rate, growth_rate_second)
|
growth_rate, growth_rate_second)
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||||
|
@ -27,7 +25,8 @@ def pm_forces(positions,
|
||||||
mesh_shape = delta.shape
|
mesh_shape = delta.shape
|
||||||
|
|
||||||
if paint_absolute_pos:
|
if paint_absolute_pos:
|
||||||
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
|
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
|
||||||
|
device=sharding),
|
||||||
pos,
|
pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
|
@ -72,7 +71,8 @@ def lpt(cosmo,
|
||||||
"""
|
"""
|
||||||
paint_absolute_pos = particles is not None
|
paint_absolute_pos = particles is not None
|
||||||
if particles is None:
|
if particles is None:
|
||||||
particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
|
particles = jnp.zeros_like(initial_conditions,
|
||||||
|
shape=(*initial_conditions.shape, 3))
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
|
@ -172,7 +172,8 @@ def make_ode_fn(mesh_shape,
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def make_diffrax_ode(cosmo, mesh_shape,
|
def make_diffrax_ode(cosmo,
|
||||||
|
mesh_shape,
|
||||||
paint_absolute_pos=True,
|
paint_absolute_pos=True,
|
||||||
halo_size=0,
|
halo_size=0,
|
||||||
sharding=None):
|
sharding=None):
|
||||||
|
@ -199,6 +200,7 @@ def make_diffrax_ode(cosmo, mesh_shape,
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def pgd_correction(pos, mesh_shape, params):
|
def pgd_correction(pos, mesh_shape, params):
|
||||||
"""
|
"""
|
||||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
|
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import numpy as np
|
||||||
from jax.scipy.stats import norm
|
from jax.scipy.stats import norm
|
||||||
from scipy.special import legendre
|
from scipy.special import legendre
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||||
'cross_correlation_coefficients', 'gaussian_smoothing'
|
'cross_correlation_coefficients', 'gaussian_smoothing'
|
||||||
|
|
|
@ -18,7 +18,7 @@ import numpy as np
|
||||||
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
|
||||||
PIDController, SaveAt, diffeqsolve)
|
PIDController, SaveAt, diffeqsolve)
|
||||||
from jax.experimental.mesh_utils import create_device_mesh
|
from jax.experimental.mesh_utils import create_device_mesh
|
||||||
from jax.experimental.multihost_utils import (process_allgather)
|
from jax.experimental.multihost_utils import process_allgather
|
||||||
from jax.sharding import Mesh, NamedSharding
|
from jax.sharding import Mesh, NamedSharding
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# Parameterized fixture for mesh_shape
|
# Parameterized fixture for mesh_shape
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
os.environ["EQX_ON_ERROR"] = "nan"
|
os.environ["EQX_ON_ERROR"] = "nan"
|
||||||
setup_done = False
|
setup_done = False
|
||||||
on_cluster = False
|
on_cluster = False
|
||||||
|
@ -10,6 +12,7 @@ def is_on_cluster():
|
||||||
global on_cluster
|
global on_cluster
|
||||||
return on_cluster
|
return on_cluster
|
||||||
|
|
||||||
|
|
||||||
def initialize_distributed():
|
def initialize_distributed():
|
||||||
global setup_done
|
global setup_done
|
||||||
global on_cluster
|
global on_cluster
|
||||||
|
@ -26,25 +29,16 @@ def initialize_distributed():
|
||||||
setup_done = True
|
setup_done = True
|
||||||
on_cluster = False
|
on_cluster = False
|
||||||
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
os.environ["JAX_PLATFORM_NAME"] = "cpu"
|
||||||
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
os.environ[
|
||||||
|
"XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def setup_and_teardown_session():
|
|
||||||
# Code to run at the start of the session
|
|
||||||
print("Starting session...")
|
|
||||||
initialize_distributed()
|
|
||||||
# Setup code here
|
|
||||||
# e.g., connecting to a database, initializing some resources, etc.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(
|
@pytest.fixture(
|
||||||
scope="session",
|
scope="session",
|
||||||
params=[
|
params=[
|
||||||
((64, 64, 64) , (512., 512., 512.)), # BOX
|
((32, 32, 32), (256., 256., 256.)), # BOX
|
||||||
((64, 64, 128) , (256. , 256. , 512.)), # RECTANGULAR
|
((32, 32, 64), (256., 256., 512.)), # RECTANGULAR
|
||||||
])
|
])
|
||||||
def simulation_config(request):
|
def simulation_config(request):
|
||||||
return request.param
|
return request.param
|
||||||
|
@ -55,11 +49,11 @@ def lpt_scale_factor(request):
|
||||||
return request.param
|
return request.param
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def cosmo():
|
def cosmo():
|
||||||
from jax_cosmo import Cosmology
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
from jax_cosmo import Cosmology
|
||||||
Planck18 = partial(
|
Planck18 = partial(
|
||||||
Cosmology,
|
Cosmology,
|
||||||
# Omega_m = 0.3111
|
# Omega_m = 0.3111
|
||||||
|
@ -85,9 +79,9 @@ def particle_mesh(simulation_config):
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def fpm_initial_conditions(cosmo, particle_mesh):
|
def fpm_initial_conditions(cosmo, particle_mesh):
|
||||||
from jax import numpy as jnp
|
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from jax import numpy as jnp
|
||||||
|
|
||||||
# Generate initial particle positions
|
# Generate initial particle positions
|
||||||
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
|
grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
|
||||||
|
@ -117,7 +111,8 @@ def initial_conditions(fpm_initial_conditions):
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def solver(cosmo, particle_mesh):
|
def solver(cosmo, particle_mesh):
|
||||||
from fastpm.core import Solver, Cosmology as FastPMCosmology
|
from fastpm.core import Cosmology as FastPMCosmology
|
||||||
|
from fastpm.core import Solver
|
||||||
ref_cosmo = FastPMCosmology(cosmo)
|
ref_cosmo = FastPMCosmology(cosmo)
|
||||||
return Solver(particle_mesh, ref_cosmo, B=1)
|
return Solver(particle_mesh, ref_cosmo, B=1)
|
||||||
|
|
||||||
|
@ -150,8 +145,8 @@ def fpm_lpt2_field(fpm_lpt2, particle_mesh):
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
||||||
from fastpm.core import leapfrog
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from fastpm.core import leapfrog
|
||||||
|
|
||||||
if lpt_scale_factor == 0.8:
|
if lpt_scale_factor == 0.8:
|
||||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||||
|
@ -166,8 +161,8 @@ def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
|
||||||
from fastpm.core import leapfrog
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from fastpm.core import leapfrog
|
||||||
|
|
||||||
if lpt_scale_factor == 0.8:
|
if lpt_scale_factor == 0.8:
|
||||||
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
pytest.skip("Do not run nbody simulation from scale factor 0.8")
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
def MSE(x, y):
|
def MSE(x, y):
|
||||||
return jnp.mean((x - y)**2)
|
return jnp.mean((x - y)**2)
|
||||||
|
|
||||||
|
|
||||||
def MSE_3D(x, y):
|
def MSE_3D(x, y):
|
||||||
return ((x - y)**2).mean(axis=0)
|
return ((x - y)**2).mean(axis=0)
|
||||||
|
|
||||||
|
|
||||||
def MSRE(x, y):
|
def MSRE(x, y):
|
||||||
return jnp.mean(((x - y) / y)**2)
|
return jnp.mean(((x - y) / y)**2)
|
Loading…
Add table
Reference in a new issue