diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index b8d888e..44177e9 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -82,7 +82,7 @@ def slice_unpad_impl(x, pad_width): def slice_pad(x, pad_width, sharding): gpu_mesh = sharding.mesh if sharding is not None else None - if not gpu_mesh is None and not (gpu_mesh.empty) and ( + if gpu_mesh is not None and not (gpu_mesh.empty) and ( pad_width[0][0] > 0 or pad_width[1][0] > 0): assert sharding is not None spec = sharding.spec @@ -96,7 +96,7 @@ def slice_pad(x, pad_width, sharding): def slice_unpad(x, pad_width, sharding): mesh = sharding.mesh if sharding is not None else None - if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0 + if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0 or pad_width[1][0] > 0): assert sharding is not None spec = sharding.spec @@ -122,20 +122,6 @@ def get_local_shape(mesh_shape, sharding=None): ] -def zeros(mesh_shape, sharding=None): - gpu_mesh = sharding.mesh if sharding is not None else None - if not gpu_mesh is None and not (gpu_mesh.empty): - local_mesh_shape = get_local_shape(mesh_shape, sharding) - spec = sharding.spec - return shard_map( - partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), - mesh=gpu_mesh, - in_specs=(), - out_specs=spec)() # yapf: disable - else: - return jnp.zeros(mesh_shape) - - def __axis_names(spec): if len(spec) == 1: x_axis, = spec @@ -158,7 +144,7 @@ def __axis_names(spec): def uniform_particles(mesh_shape, sharding=None): gpu_mesh = sharding.mesh if sharding is not None else None - if not gpu_mesh is None and not (gpu_mesh.empty): + if gpu_mesh is not None and not (gpu_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape, sharding) spec = sharding.spec x_axis, y_axis, single_axis = __axis_names(spec) @@ -183,7 +169,7 @@ def uniform_particles(mesh_shape, sharding=None): def normal_field(mesh_shape, seed, sharding=None): """Generate a Gaussian random field with the given power spectrum.""" gpu_mesh = sharding.mesh if sharding is not None else None - if not gpu_mesh is None and not (gpu_mesh.empty): + if gpu_mesh is not None and not (gpu_mesh.empty): local_mesh_shape = get_local_shape(mesh_shape, sharding) size = jax.device_count() diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 170f5e9..7672093 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,5 +1,4 @@ import jax.numpy as jnp -import jax_cosmo as jc import numpy as np from jax.lib.xla_client import FftType from jax.sharding import PartitionSpec as P diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 76bb9b6..aec41e8 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -204,7 +204,7 @@ def cic_paint_dx(displacements, return grid_mesh -def cic_read_dx_impl(grid_mesh, halo_size): +def cic_read_dx_impl(grid_mesh, disp, halo_size): halo_x, _ = halo_size[0] halo_y, _ = halo_size[1] @@ -220,14 +220,15 @@ def cic_read_dx_impl(grid_mesh, halo_size): pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1) pmid = pmid.reshape([-1, 3]) + disp = disp.reshape([-1, 3]) - return gather(pmid, jnp.zeros_like(pmid), + return gather(pmid, disp, grid_mesh).reshape(original_shape) -@partial(jax.jit, static_argnums=(1, 2)) -def cic_read_dx(grid_mesh, halo_size=0, sharding=None): - # return mesh +@partial(jax.jit, static_argnums=(2, 3)) +def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None): + halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding) grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding) grid_mesh = halo_exchange(grid_mesh, @@ -238,7 +239,7 @@ def cic_read_dx(grid_mesh, halo_size=0, sharding=None): displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size), gpu_mesh=gpu_mesh, in_specs=(spec), - out_specs=spec)(grid_mesh) + out_specs=spec)(grid_mesh , disp) return displacements diff --git a/jaxpm/painting_utils.py b/jaxpm/painting_utils.py index 8742ccd..916b457 100644 --- a/jaxpm/painting_utils.py +++ b/jaxpm/painting_utils.py @@ -25,72 +25,71 @@ def _chunk_split(ptcl_num, chunk_size, *arrays): return remainder, chunks -def enmesh(i1, d1, a1, s1, b12, a2, s2): +def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape): """Multilinear enmeshing.""" - i1 = jnp.asarray(i1) - d1 = jnp.asarray(d1) + base_indices = jnp.asarray(base_indices) + displacements = jnp.asarray(displacements) with jax.experimental.enable_x64(): - a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1, - dtype=d1.dtype) - if s1 is not None: - s1 = jnp.array(s1, dtype=i1.dtype) - b12 = jnp.float64(b12) - if a2 is not None: - a2 = jnp.float64(a2) - if s2 is not None: - s2 = jnp.array(s2, dtype=i1.dtype) + cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype) + if base_shape is not None: + base_shape = jnp.array(base_shape, dtype=base_indices.dtype) + offset = jnp.float64(offset) + if new_cell_size is not None: + new_cell_size = jnp.float64(new_cell_size) + if new_shape is not None: + new_shape = jnp.array(new_shape, dtype=base_indices.dtype) - dim = i1.shape[1] - neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >> - jnp.arange(dim, dtype=i1.dtype)) & 1 + spatial_dim = base_indices.shape[1] + neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >> + jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1 - if a2 is not None: - P = i1 * a1 + d1 - b12 - P = P[:, jnp.newaxis] # insert neighbor axis - i2 = P + neighbors * a2 # multilinear + if new_cell_size is not None: + particle_positions = base_indices * cell_size + displacements - offset + particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis + new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear - if s1 is not None: - L = s1 * a1 - i2 %= L + if base_shape is not None: + grid_length = base_shape * cell_size + new_indices %= grid_length - i2 //= a2 - d2 = P - i2 * a2 + new_indices //= new_cell_size + new_displacements = particle_positions - new_indices * new_cell_size - if s1 is not None: - d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected + if base_shape is not None: + new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected - i2 = i2.astype(i1.dtype) - d2 = d2.astype(d1.dtype) - a2 = a2.astype(d1.dtype) + new_indices = new_indices.astype(base_indices.dtype) + new_displacements = new_displacements.astype(displacements.dtype) + new_cell_size = new_cell_size.astype(displacements.dtype) - d2 /= a2 + new_displacements /= new_cell_size else: - i12, d12 = jnp.divmod(b12, a1) - i1 -= i12.astype(i1.dtype) - d1 -= d12.astype(d1.dtype) + offset_indices, offset_displacements = jnp.divmod(offset, cell_size) + base_indices -= offset_indices.astype(base_indices.dtype) + displacements -= offset_displacements.astype(displacements.dtype) # insert neighbor axis - i1 = i1[:, jnp.newaxis] - d1 = d1[:, jnp.newaxis] + base_indices = base_indices[:, jnp.newaxis] + displacements = displacements[:, jnp.newaxis] # multilinear - d1 /= a1 - i2 = jnp.floor(d1).astype(i1.dtype) - i2 += neighbors - d2 = d1 - i2 - i2 += i1 + displacements /= cell_size + new_indices = jnp.floor(displacements).astype(base_indices.dtype) + new_indices += neighbor_offsets + new_displacements = displacements - new_indices + new_indices += base_indices - if s1 is not None: - i2 %= s1 + if base_shape is not None: + new_indices %= base_shape - f2 = 1 - jnp.abs(d2) + weights = 1 - jnp.abs(new_displacements) - if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None - i2 = jnp.where(i2 < 0, s2, i2) + if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None + new_indices = jnp.where(new_indices < 0, new_shape, new_indices) - f2 = f2.prod(axis=-1) + weights = weights.prod(axis=-1) - return i2, f2 + return new_indices, weights def _scatter_chunk(carry, chunk): @@ -138,7 +137,7 @@ def _chunk_cat(remainder_array, chunked_array): return array -def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.): +def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.): ptcl_num, spatial_ndim = pmid.shape mesh = jnp.asarray(mesh) diff --git a/jaxpm/plotting.py b/jaxpm/plotting.py index 4819207..9fe4d8e 100644 --- a/jaxpm/plotting.py +++ b/jaxpm/plotting.py @@ -1,4 +1,3 @@ -import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 262b916..2b2bcab 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,11 +1,9 @@ -from functools import partial import jax.numpy as jnp import jax_cosmo as jc -from jax.sharding import PartitionSpec as P -from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d, - normal_field, zeros) +from jaxpm.distributed import (fft3d, ifft3d, + normal_field) from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second, growth_rate, growth_rate_second) from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, @@ -29,17 +27,17 @@ def pm_forces(positions, mesh_shape = delta.shape if paint_absolute_pos: - paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding), - x, - halo_size=halo_size, - sharding=sharding) - read_fn = lambda x: cic_read( - x, positions, halo_size=halo_size, sharding=sharding) + paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding), + pos, + halo_size=halo_size, + sharding=sharding) + read_fn = lambda grid_mesh, pos: cic_read( + grid_mesh, pos, halo_size=halo_size, sharding=sharding) else: - paint_fn = partial(cic_paint_dx, - halo_size=halo_size, - sharding=sharding) - read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding) + paint_fn = lambda disp: cic_paint_dx( + disp, halo_size=halo_size, sharding=sharding) + read_fn = lambda grid_mesh, disp: cic_read_dx( + grid_mesh, disp, halo_size=halo_size, sharding=sharding) if delta is None: field = paint_fn(positions) @@ -55,7 +53,7 @@ def pm_forces(positions, kvec, r_split=r_split) # Computes gravitational forces forces = jnp.stack([ - read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k), + read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions ) for i in range(3)], axis=-1) # yapf: disable return forces @@ -73,6 +71,8 @@ def lpt(cosmo, e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258) """ paint_absolute_pos = particles is not None + if particles is None: + particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3)) a = jnp.atleast_1d(a) E = jnp.sqrt(jc.background.Esqr(cosmo, a)) @@ -167,25 +167,27 @@ def make_ode_fn(mesh_shape, # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - #dpos = dpos if not paint_absolute_pos else dpos + pos - return dpos, dvel return nbody_ode -def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None): +def make_diffrax_ode(cosmo, mesh_shape, + paint_absolute_pos=True, + halo_size=0, + sharding=None): def nbody_ode(a, state, args): """ - State is an array [position, velocities] - - Compatible with [Diffrax API](https://docs.kidger.site/diffrax/) + state is a tuple (position, velocities) """ pos, vel = state - forces = pm_forces( - pos, mesh_shape, halo_size=halo_size, - sharding=sharding) * 1.5 * cosmo.Omega_m + + forces = pm_forces(pos, + mesh_shape=mesh_shape, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel @@ -197,7 +199,6 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None): return nbody_ode - def pgd_correction(pos, mesh_shape, params): """ improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, diff --git a/jaxpm/utils.py b/jaxpm/utils.py index 659ab3f..4e140e5 100644 --- a/jaxpm/utils.py +++ b/jaxpm/utils.py @@ -5,7 +5,6 @@ import numpy as np from jax.scipy.stats import norm from scipy.special import legendre -from jaxpm.growth import growth_factor, growth_rate __all__ = [ 'power_spectrum', 'transfer', 'coherence', 'pktranscoh',