diff --git a/jaxpm/painting.py b/jaxpm/painting.py
index aec41e8..db695e3 100644
--- a/jaxpm/painting.py
+++ b/jaxpm/painting.py
@@ -222,12 +222,11 @@ def cic_read_dx_impl(grid_mesh, disp, halo_size):
 
     pmid = pmid.reshape([-1, 3])
     disp = disp.reshape([-1, 3])
-    return gather(pmid, disp,
-                  grid_mesh).reshape(original_shape)
+    return gather(pmid, disp, grid_mesh).reshape(original_shape)
 
 
 @partial(jax.jit, static_argnums=(2, 3))
-def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
+def cic_read_dx(grid_mesh, disp, halo_size=0, sharding=None):
 
     halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
     grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
@@ -239,7 +238,7 @@ def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
     displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
                               gpu_mesh=gpu_mesh,
                               in_specs=(spec),
-                              out_specs=spec)(grid_mesh , disp)
+                              out_specs=spec)(grid_mesh, disp)
 
     return displacements
 
diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py
index 916b457..cf68f9d 100644
--- a/jaxpm/painting_utils.py
+++ b/jaxpm/painting_utils.py
@@ -25,12 +25,15 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
     return remainder, chunks
 
 
-def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
+def enmesh(base_indices, displacements, cell_size, base_shape, offset,
+           new_cell_size, new_shape):
     """Multilinear enmeshing."""
     base_indices = jnp.asarray(base_indices)
     displacements = jnp.asarray(displacements)
     with jax.experimental.enable_x64():
-        cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
+        cell_size = jnp.float64(
+            cell_size) if new_cell_size is not None else jnp.array(
+                cell_size, dtype=displacements.dtype)
         if base_shape is not None:
             base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
         offset = jnp.float64(offset)
@@ -40,12 +43,14 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_
             new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
 
     spatial_dim = base_indices.shape[1]
-    neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
-                        jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
+    neighbor_offsets = (
+        jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
+        jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
 
     if new_cell_size is not None:
         particle_positions = base_indices * cell_size + displacements - offset
-        particle_positions = particle_positions[:, jnp.newaxis]  # insert neighbor axis
+        particle_positions = particle_positions[:, jnp.
+                                                newaxis]  # insert neighbor axis
         new_indices = particle_positions + neighbor_offsets * new_cell_size  # multilinear
 
         if base_shape is not None:
@@ -56,7 +61,9 @@ def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_
         new_displacements = particle_positions - new_indices * new_cell_size
 
         if base_shape is not None:
-            new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length  # also abs(new_displacements) < new_cell_size is expected
+            new_displacements -= jnp.rint(
+                new_displacements / grid_length
+            ) * grid_length  # also abs(new_displacements) < new_cell_size is expected
 
     new_indices = new_indices.astype(base_indices.dtype)
     new_displacements = new_displacements.astype(displacements.dtype)
diff --git a/jaxpm/pm.py b/jaxpm/pm.py
index 2b2bcab..e34d584 100644
--- a/jaxpm/pm.py
+++ b/jaxpm/pm.py
@@ -1,9 +1,7 @@
-
 import jax.numpy as jnp
 import jax_cosmo as jc
 
-from jaxpm.distributed import (fft3d, ifft3d,
-                               normal_field)
+from jaxpm.distributed import fft3d, ifft3d, normal_field
 from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
                           growth_rate, growth_rate_second)
 from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
@@ -27,7 +25,8 @@ def pm_forces(positions,
         mesh_shape = delta.shape
 
     if paint_absolute_pos:
-        paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
+        paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape,
+                                                   device=sharding),
                                          pos,
                                          halo_size=halo_size,
                                          sharding=sharding)
@@ -72,7 +71,8 @@ def lpt(cosmo,
     """
     paint_absolute_pos = particles is not None
     if particles is None:
-        particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
+        particles = jnp.zeros_like(initial_conditions,
+                                   shape=(*initial_conditions.shape, 3))
 
     a = jnp.atleast_1d(a)
     E = jnp.sqrt(jc.background.Esqr(cosmo, a))
@@ -172,10 +172,11 @@ def make_ode_fn(mesh_shape,
     return nbody_ode
 
 
-def make_diffrax_ode(cosmo, mesh_shape,
-                     paint_absolute_pos=True,
-                     halo_size=0,
-                     sharding=None):
+def make_diffrax_ode(cosmo,
+                     mesh_shape,
+                     paint_absolute_pos=True,
+                     halo_size=0,
+                     sharding=None):
 
     def nbody_ode(a, state, args):
         """
@@ -199,6 +200,7 @@ def make_diffrax_ode(cosmo, mesh_shape,
 
     return nbody_ode
 
+
 def pgd_correction(pos, mesh_shape, params):
     """
     improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
diff --git a/jaxpm/utils.py b/jaxpm/utils.py
index 4e140e5..96faeea 100644
--- a/jaxpm/utils.py
+++ b/jaxpm/utils.py
@@ -5,7 +5,6 @@ import numpy as np
 from jax.scipy.stats import norm
 from scipy.special import legendre
 
-
 __all__ = [
     'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
     'cross_correlation_coefficients', 'gaussian_smoothing'
diff --git a/notebooks/05-MultiHost_PM.py b/notebooks/05-MultiHost_PM.py
index 3e223b7..da3964e 100644
--- a/notebooks/05-MultiHost_PM.py
+++ b/notebooks/05-MultiHost_PM.py
@@ -18,7 +18,7 @@ import numpy as np
 from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm,
                      PIDController, SaveAt, diffeqsolve)
 from jax.experimental.mesh_utils import create_device_mesh
-from jax.experimental.multihost_utils import (process_allgather)
+from jax.experimental.multihost_utils import process_allgather
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 94c3fb2..6d91684 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,65 +1,59 @@
 # Parameterized fixture for mesh_shape
 import os
+
 import pytest
+
 os.environ["EQX_ON_ERROR"] = "nan"
 setup_done = False
 on_cluster = False
 
 
 def is_on_cluster():
-  global on_cluster
-  return on_cluster
+    global on_cluster
+    return on_cluster
+
 
 def initialize_distributed():
-  global setup_done
-  global on_cluster
-  if not setup_done:
-    if "SLURM_JOB_ID" in os.environ:
-      on_cluster = True
-      print("Running on cluster")
-      import jax
-      jax.distributed.initialize()
-      setup_done = True
-      on_cluster = True
-    else:
-      print("Running locally")
-      setup_done = True
-      on_cluster = False
-      os.environ["JAX_PLATFORM_NAME"] = "cpu"
-      os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
-      import jax
-
-
-@pytest.fixture(scope="session", autouse=True)
-def setup_and_teardown_session():
-  # Code to run at the start of the session
-  print("Starting session...")
-  initialize_distributed()
-  # Setup code here
-  # e.g., connecting to a database, initializing some resources, etc.
-
+    global setup_done
+    global on_cluster
+    if not setup_done:
+        if "SLURM_JOB_ID" in os.environ:
+            on_cluster = True
+            print("Running on cluster")
+            import jax
+            jax.distributed.initialize()
+            setup_done = True
+            on_cluster = True
+        else:
+            print("Running locally")
+            setup_done = True
+            on_cluster = False
+            os.environ["JAX_PLATFORM_NAME"] = "cpu"
+            os.environ[
+                "XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
+            import jax
 
 
 @pytest.fixture(
     scope="session",
     params=[
-        ((64, 64, 64) , (512., 512., 512.)),  # BOX
-        ((64, 64, 128) , (256. , 256. , 512.)),  # RECTANGULAR
+        ((32, 32, 32), (256., 256., 256.)),  # BOX
+        ((32, 32, 64), (256., 256., 512.)),  # RECTANGULAR
     ])
 def simulation_config(request):
     return request.param
 
 
-@pytest.fixture(scope="session", params=[0.1 , 0.5 , 0.8])
+@pytest.fixture(scope="session", params=[0.1, 0.5, 0.8])
 def lpt_scale_factor(request):
     return request.param
 
-
 @pytest.fixture(scope="session")
 def cosmo():
-    from jax_cosmo import Cosmology
     from functools import partial
+
+    from jax_cosmo import Cosmology
     Planck18 = partial(
         Cosmology,
         # Omega_m = 0.3111
@@ -85,9 +79,9 @@ def particle_mesh(simulation_config):
 
 @pytest.fixture(scope="session")
 def fpm_initial_conditions(cosmo, particle_mesh):
-    from jax import numpy as jnp
     import jax_cosmo as jc
     import numpy as np
+    from jax import numpy as jnp
 
     # Generate initial particle positions
     grid = particle_mesh.generate_uniform_particle_grid(shift=0).astype(
@@ -117,7 +111,8 @@ def initial_conditions(fpm_initial_conditions):
 
 @pytest.fixture(scope="session")
 def solver(cosmo, particle_mesh):
-    from fastpm.core import Solver, Cosmology as FastPMCosmology
+    from fastpm.core import Cosmology as FastPMCosmology
+    from fastpm.core import Solver
     ref_cosmo = FastPMCosmology(cosmo)
     return Solver(particle_mesh, ref_cosmo, B=1)
 
@@ -150,8 +145,8 @@ def fpm_lpt2_field(fpm_lpt2, particle_mesh):
 
 @pytest.fixture(scope="session")
 def nbody_from_lpt1(solver, fpm_lpt1, particle_mesh, lpt_scale_factor):
-    from fastpm.core import leapfrog
     import numpy as np
+    from fastpm.core import leapfrog
 
     if lpt_scale_factor == 0.8:
         pytest.skip("Do not run nbody simulation from scale factor 0.8")
@@ -166,8 +161,8 @@ def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
 
 @pytest.fixture(scope="session")
 def nbody_from_lpt2(solver, fpm_lpt2, particle_mesh, lpt_scale_factor):
-    from fastpm.core import leapfrog
     import numpy as np
+    from fastpm.core import leapfrog
 
     if lpt_scale_factor == 0.8:
         pytest.skip("Do not run nbody simulation from scale factor 0.8")
diff --git a/tests/helpers.py b/tests/helpers.py
index 3369a8c..0b85161 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -1,10 +1,13 @@
 import jax.numpy as jnp
 
-def MSE(x , y):
+
+def MSE(x, y):
     return jnp.mean((x - y)**2)
 
-def MSE_3D(x , y):
+
+def MSE_3D(x, y):
     return ((x - y)**2).mean(axis=0)
 
-def MSRE(x , y):
-    return jnp.mean(((x - y)/ y)**2)
\ No newline at end of file
+
+def MSRE(x, y):
+    return jnp.mean(((x - y) / y)**2)
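
Reviewer note (illustration only, not part of the patch): the `neighbor_offsets` expression reformatted in the `enmesh` hunk above is a bit-twiddling idiom that enumerates all 2**d corner offsets of a grid cell, which is what makes the enmeshing multilinear. A minimal standalone sketch of that one expression, assuming only `jax.numpy` (the dtype arguments are dropped for brevity):

# Each row i of neighbor_offsets is the binary expansion of i, i.e. one
# corner of the unit cell; enmesh() scales these by new_cell_size to visit
# the 2**spatial_dim multilinear neighbors of every particle.
import jax.numpy as jnp

spatial_dim = 3  # 3D particle mesh, as in enmesh()
neighbor_offsets = (jnp.arange(2**spatial_dim)[:, jnp.newaxis] >>
                    jnp.arange(spatial_dim)) & 1
print(neighbor_offsets.tolist())
# [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
#  [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]]

Added to each particle's cell position (after the `[:, jnp.newaxis]` broadcast in the same hunk), these offsets give the eight neighboring grid points whose interpolation weights the surrounding code computes.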