mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-12 06:00:54 +00:00
update code
This commit is contained in:
parent
e0c118a540
commit
21373b89ee
7 changed files with 84 additions and 100 deletions
|
@ -82,7 +82,7 @@ def slice_unpad_impl(x, pad_width):
|
||||||
|
|
||||||
def slice_pad(x, pad_width, sharding):
|
def slice_pad(x, pad_width, sharding):
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty) and (
|
if gpu_mesh is not None and not (gpu_mesh.empty) and (
|
||||||
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
pad_width[0][0] > 0 or pad_width[1][0] > 0):
|
||||||
assert sharding is not None
|
assert sharding is not None
|
||||||
spec = sharding.spec
|
spec = sharding.spec
|
||||||
|
@ -96,7 +96,7 @@ def slice_pad(x, pad_width, sharding):
|
||||||
|
|
||||||
def slice_unpad(x, pad_width, sharding):
|
def slice_unpad(x, pad_width, sharding):
|
||||||
mesh = sharding.mesh if sharding is not None else None
|
mesh = sharding.mesh if sharding is not None else None
|
||||||
if not mesh is None and not (mesh.empty) and (pad_width[0][0] > 0
|
if mesh is not None and not (mesh.empty) and (pad_width[0][0] > 0
|
||||||
or pad_width[1][0] > 0):
|
or pad_width[1][0] > 0):
|
||||||
assert sharding is not None
|
assert sharding is not None
|
||||||
spec = sharding.spec
|
spec = sharding.spec
|
||||||
|
@ -122,20 +122,6 @@ def get_local_shape(mesh_shape, sharding=None):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def zeros(mesh_shape, sharding=None):
|
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
|
||||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
|
||||||
spec = sharding.spec
|
|
||||||
return shard_map(
|
|
||||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
|
||||||
mesh=gpu_mesh,
|
|
||||||
in_specs=(),
|
|
||||||
out_specs=spec)() # yapf: disable
|
|
||||||
else:
|
|
||||||
return jnp.zeros(mesh_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def __axis_names(spec):
|
def __axis_names(spec):
|
||||||
if len(spec) == 1:
|
if len(spec) == 1:
|
||||||
x_axis, = spec
|
x_axis, = spec
|
||||||
|
@ -158,7 +144,7 @@ def __axis_names(spec):
|
||||||
def uniform_particles(mesh_shape, sharding=None):
|
def uniform_particles(mesh_shape, sharding=None):
|
||||||
|
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
spec = sharding.spec
|
spec = sharding.spec
|
||||||
x_axis, y_axis, single_axis = __axis_names(spec)
|
x_axis, y_axis, single_axis = __axis_names(spec)
|
||||||
|
@ -183,7 +169,7 @@ def uniform_particles(mesh_shape, sharding=None):
|
||||||
def normal_field(mesh_shape, seed, sharding=None):
|
def normal_field(mesh_shape, seed, sharding=None):
|
||||||
"""Generate a Gaussian random field with the given power spectrum."""
|
"""Generate a Gaussian random field with the given power spectrum."""
|
||||||
gpu_mesh = sharding.mesh if sharding is not None else None
|
gpu_mesh = sharding.mesh if sharding is not None else None
|
||||||
if not gpu_mesh is None and not (gpu_mesh.empty):
|
if gpu_mesh is not None and not (gpu_mesh.empty):
|
||||||
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
local_mesh_shape = get_local_shape(mesh_shape, sharding)
|
||||||
|
|
||||||
size = jax.device_count()
|
size = jax.device_count()
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from jax.lib.xla_client import FftType
|
from jax.lib.xla_client import FftType
|
||||||
from jax.sharding import PartitionSpec as P
|
from jax.sharding import PartitionSpec as P
|
||||||
|
|
|
@ -204,7 +204,7 @@ def cic_paint_dx(displacements,
|
||||||
return grid_mesh
|
return grid_mesh
|
||||||
|
|
||||||
|
|
||||||
def cic_read_dx_impl(grid_mesh, halo_size):
|
def cic_read_dx_impl(grid_mesh, disp, halo_size):
|
||||||
|
|
||||||
halo_x, _ = halo_size[0]
|
halo_x, _ = halo_size[0]
|
||||||
halo_y, _ = halo_size[1]
|
halo_y, _ = halo_size[1]
|
||||||
|
@ -220,14 +220,15 @@ def cic_read_dx_impl(grid_mesh, halo_size):
|
||||||
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
pmid = jnp.stack([a + halo_x, b + halo_y, c], axis=-1)
|
||||||
|
|
||||||
pmid = pmid.reshape([-1, 3])
|
pmid = pmid.reshape([-1, 3])
|
||||||
|
disp = disp.reshape([-1, 3])
|
||||||
|
|
||||||
return gather(pmid, jnp.zeros_like(pmid),
|
return gather(pmid, disp,
|
||||||
grid_mesh).reshape(original_shape)
|
grid_mesh).reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(1, 2))
|
@partial(jax.jit, static_argnums=(2, 3))
|
||||||
def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
|
def cic_read_dx(grid_mesh,disp , halo_size=0, sharding=None):
|
||||||
# return mesh
|
|
||||||
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
halo_size, halo_extents = get_halo_size(halo_size, sharding=sharding)
|
||||||
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
grid_mesh = slice_pad(grid_mesh, halo_size, sharding=sharding)
|
||||||
grid_mesh = halo_exchange(grid_mesh,
|
grid_mesh = halo_exchange(grid_mesh,
|
||||||
|
@ -238,7 +239,7 @@ def cic_read_dx(grid_mesh, halo_size=0, sharding=None):
|
||||||
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
displacements = autoshmap(partial(cic_read_dx_impl, halo_size=halo_size),
|
||||||
gpu_mesh=gpu_mesh,
|
gpu_mesh=gpu_mesh,
|
||||||
in_specs=(spec),
|
in_specs=(spec),
|
||||||
out_specs=spec)(grid_mesh)
|
out_specs=spec)(grid_mesh , disp)
|
||||||
|
|
||||||
return displacements
|
return displacements
|
||||||
|
|
||||||
|
|
|
@ -25,72 +25,71 @@ def _chunk_split(ptcl_num, chunk_size, *arrays):
|
||||||
return remainder, chunks
|
return remainder, chunks
|
||||||
|
|
||||||
|
|
||||||
def enmesh(i1, d1, a1, s1, b12, a2, s2):
|
def enmesh(base_indices, displacements, cell_size, base_shape, offset, new_cell_size, new_shape):
|
||||||
"""Multilinear enmeshing."""
|
"""Multilinear enmeshing."""
|
||||||
i1 = jnp.asarray(i1)
|
base_indices = jnp.asarray(base_indices)
|
||||||
d1 = jnp.asarray(d1)
|
displacements = jnp.asarray(displacements)
|
||||||
with jax.experimental.enable_x64():
|
with jax.experimental.enable_x64():
|
||||||
a1 = jnp.float64(a1) if a2 is not None else jnp.array(a1,
|
cell_size = jnp.float64(cell_size) if new_cell_size is not None else jnp.array(cell_size, dtype=displacements.dtype)
|
||||||
dtype=d1.dtype)
|
if base_shape is not None:
|
||||||
if s1 is not None:
|
base_shape = jnp.array(base_shape, dtype=base_indices.dtype)
|
||||||
s1 = jnp.array(s1, dtype=i1.dtype)
|
offset = jnp.float64(offset)
|
||||||
b12 = jnp.float64(b12)
|
if new_cell_size is not None:
|
||||||
if a2 is not None:
|
new_cell_size = jnp.float64(new_cell_size)
|
||||||
a2 = jnp.float64(a2)
|
if new_shape is not None:
|
||||||
if s2 is not None:
|
new_shape = jnp.array(new_shape, dtype=base_indices.dtype)
|
||||||
s2 = jnp.array(s2, dtype=i1.dtype)
|
|
||||||
|
|
||||||
dim = i1.shape[1]
|
spatial_dim = base_indices.shape[1]
|
||||||
neighbors = (jnp.arange(2**dim, dtype=i1.dtype)[:, jnp.newaxis] >>
|
neighbor_offsets = (jnp.arange(2**spatial_dim, dtype=base_indices.dtype)[:, jnp.newaxis] >>
|
||||||
jnp.arange(dim, dtype=i1.dtype)) & 1
|
jnp.arange(spatial_dim, dtype=base_indices.dtype)) & 1
|
||||||
|
|
||||||
if a2 is not None:
|
if new_cell_size is not None:
|
||||||
P = i1 * a1 + d1 - b12
|
particle_positions = base_indices * cell_size + displacements - offset
|
||||||
P = P[:, jnp.newaxis] # insert neighbor axis
|
particle_positions = particle_positions[:, jnp.newaxis] # insert neighbor axis
|
||||||
i2 = P + neighbors * a2 # multilinear
|
new_indices = particle_positions + neighbor_offsets * new_cell_size # multilinear
|
||||||
|
|
||||||
if s1 is not None:
|
if base_shape is not None:
|
||||||
L = s1 * a1
|
grid_length = base_shape * cell_size
|
||||||
i2 %= L
|
new_indices %= grid_length
|
||||||
|
|
||||||
i2 //= a2
|
new_indices //= new_cell_size
|
||||||
d2 = P - i2 * a2
|
new_displacements = particle_positions - new_indices * new_cell_size
|
||||||
|
|
||||||
if s1 is not None:
|
if base_shape is not None:
|
||||||
d2 -= jnp.rint(d2 / L) * L # also abs(d2) < a2 is expected
|
new_displacements -= jnp.rint(new_displacements / grid_length) * grid_length # also abs(new_displacements) < new_cell_size is expected
|
||||||
|
|
||||||
i2 = i2.astype(i1.dtype)
|
new_indices = new_indices.astype(base_indices.dtype)
|
||||||
d2 = d2.astype(d1.dtype)
|
new_displacements = new_displacements.astype(displacements.dtype)
|
||||||
a2 = a2.astype(d1.dtype)
|
new_cell_size = new_cell_size.astype(displacements.dtype)
|
||||||
|
|
||||||
d2 /= a2
|
new_displacements /= new_cell_size
|
||||||
else:
|
else:
|
||||||
i12, d12 = jnp.divmod(b12, a1)
|
offset_indices, offset_displacements = jnp.divmod(offset, cell_size)
|
||||||
i1 -= i12.astype(i1.dtype)
|
base_indices -= offset_indices.astype(base_indices.dtype)
|
||||||
d1 -= d12.astype(d1.dtype)
|
displacements -= offset_displacements.astype(displacements.dtype)
|
||||||
|
|
||||||
# insert neighbor axis
|
# insert neighbor axis
|
||||||
i1 = i1[:, jnp.newaxis]
|
base_indices = base_indices[:, jnp.newaxis]
|
||||||
d1 = d1[:, jnp.newaxis]
|
displacements = displacements[:, jnp.newaxis]
|
||||||
|
|
||||||
# multilinear
|
# multilinear
|
||||||
d1 /= a1
|
displacements /= cell_size
|
||||||
i2 = jnp.floor(d1).astype(i1.dtype)
|
new_indices = jnp.floor(displacements).astype(base_indices.dtype)
|
||||||
i2 += neighbors
|
new_indices += neighbor_offsets
|
||||||
d2 = d1 - i2
|
new_displacements = displacements - new_indices
|
||||||
i2 += i1
|
new_indices += base_indices
|
||||||
|
|
||||||
if s1 is not None:
|
if base_shape is not None:
|
||||||
i2 %= s1
|
new_indices %= base_shape
|
||||||
|
|
||||||
f2 = 1 - jnp.abs(d2)
|
weights = 1 - jnp.abs(new_displacements)
|
||||||
|
|
||||||
if s1 is None and s2 is not None: # all i2 >= 0 if s1 is not None
|
if base_shape is None and new_shape is not None: # all new_indices >= 0 if base_shape is not None
|
||||||
i2 = jnp.where(i2 < 0, s2, i2)
|
new_indices = jnp.where(new_indices < 0, new_shape, new_indices)
|
||||||
|
|
||||||
f2 = f2.prod(axis=-1)
|
weights = weights.prod(axis=-1)
|
||||||
|
|
||||||
return i2, f2
|
return new_indices, weights
|
||||||
|
|
||||||
|
|
||||||
def _scatter_chunk(carry, chunk):
|
def _scatter_chunk(carry, chunk):
|
||||||
|
@ -138,7 +137,7 @@ def _chunk_cat(remainder_array, chunked_array):
|
||||||
return array
|
return array
|
||||||
|
|
||||||
|
|
||||||
def gather(pmid, disp, mesh, chunk_size=2**24, val=1, offset=0, cell_size=1.):
|
def gather(pmid, disp, mesh, chunk_size=2**24, val=0, offset=0, cell_size=1.):
|
||||||
ptcl_num, spatial_ndim = pmid.shape
|
ptcl_num, spatial_ndim = pmid.shape
|
||||||
|
|
||||||
mesh = jnp.asarray(mesh)
|
mesh = jnp.asarray(mesh)
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import jax.numpy as jnp
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
51
jaxpm/pm.py
51
jaxpm/pm.py
|
@ -1,11 +1,9 @@
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
from jax.sharding import PartitionSpec as P
|
|
||||||
|
|
||||||
from jaxpm.distributed import (autoshmap, fft3d, get_local_shape, ifft3d,
|
from jaxpm.distributed import (fft3d, ifft3d,
|
||||||
normal_field, zeros)
|
normal_field)
|
||||||
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
from jaxpm.growth import (dGf2a, dGfa, growth_factor, growth_factor_second,
|
||||||
growth_rate, growth_rate_second)
|
growth_rate, growth_rate_second)
|
||||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel,
|
||||||
|
@ -29,17 +27,17 @@ def pm_forces(positions,
|
||||||
mesh_shape = delta.shape
|
mesh_shape = delta.shape
|
||||||
|
|
||||||
if paint_absolute_pos:
|
if paint_absolute_pos:
|
||||||
paint_fn = lambda x: cic_paint(zeros(mesh_shape, sharding),
|
paint_fn = lambda pos: cic_paint(jnp.zeros(shape=mesh_shape , device=sharding),
|
||||||
x,
|
pos,
|
||||||
halo_size=halo_size,
|
halo_size=halo_size,
|
||||||
sharding=sharding)
|
sharding=sharding)
|
||||||
read_fn = lambda x: cic_read(
|
read_fn = lambda grid_mesh, pos: cic_read(
|
||||||
x, positions, halo_size=halo_size, sharding=sharding)
|
grid_mesh, pos, halo_size=halo_size, sharding=sharding)
|
||||||
else:
|
else:
|
||||||
paint_fn = partial(cic_paint_dx,
|
paint_fn = lambda disp: cic_paint_dx(
|
||||||
halo_size=halo_size,
|
disp, halo_size=halo_size, sharding=sharding)
|
||||||
sharding=sharding)
|
read_fn = lambda grid_mesh, disp: cic_read_dx(
|
||||||
read_fn = partial(cic_read_dx, halo_size=halo_size, sharding=sharding)
|
grid_mesh, disp, halo_size=halo_size, sharding=sharding)
|
||||||
|
|
||||||
if delta is None:
|
if delta is None:
|
||||||
field = paint_fn(positions)
|
field = paint_fn(positions)
|
||||||
|
@ -55,7 +53,7 @@ def pm_forces(positions,
|
||||||
kvec, r_split=r_split)
|
kvec, r_split=r_split)
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
forces = jnp.stack([
|
forces = jnp.stack([
|
||||||
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),
|
read_fn(ifft3d(-gradient_kernel(kvec, i) * pot_k),positions
|
||||||
) for i in range(3)], axis=-1) # yapf: disable
|
) for i in range(3)], axis=-1) # yapf: disable
|
||||||
|
|
||||||
return forces
|
return forces
|
||||||
|
@ -73,6 +71,8 @@ def lpt(cosmo,
|
||||||
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
e.g. Eq. 2 and 3 [Jenkins2010](https://arxiv.org/pdf/0910.0258)
|
||||||
"""
|
"""
|
||||||
paint_absolute_pos = particles is not None
|
paint_absolute_pos = particles is not None
|
||||||
|
if particles is None:
|
||||||
|
particles = jnp.zeros_like(initial_conditions , shape=(*initial_conditions.shape , 3))
|
||||||
|
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
E = jnp.sqrt(jc.background.Esqr(cosmo, a))
|
||||||
|
@ -167,25 +167,27 @@ def make_ode_fn(mesh_shape,
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
|
||||||
|
|
||||||
#dpos = dpos if not paint_absolute_pos else dpos + pos
|
|
||||||
|
|
||||||
return dpos, dvel
|
return dpos, dvel
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
def make_diffrax_ode(cosmo, mesh_shape,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
def nbody_ode(a, state, args):
|
def nbody_ode(a, state, args):
|
||||||
"""
|
"""
|
||||||
State is an array [position, velocities]
|
state is a tuple (position, velocities)
|
||||||
|
|
||||||
Compatible with [Diffrax API](https://docs.kidger.site/diffrax/)
|
|
||||||
"""
|
"""
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
forces = pm_forces(
|
|
||||||
pos, mesh_shape, halo_size=halo_size,
|
forces = pm_forces(pos,
|
||||||
sharding=sharding) * 1.5 * cosmo.Omega_m
|
mesh_shape=mesh_shape,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
@ -197,7 +199,6 @@ def get_ode_fn(cosmo, mesh_shape, halo_size=0, sharding=None):
|
||||||
|
|
||||||
return nbody_ode
|
return nbody_ode
|
||||||
|
|
||||||
|
|
||||||
def pgd_correction(pos, mesh_shape, params):
|
def pgd_correction(pos, mesh_shape, params):
|
||||||
"""
|
"""
|
||||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
|
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import numpy as np
|
||||||
from jax.scipy.stats import norm
|
from jax.scipy.stats import norm
|
||||||
from scipy.special import legendre
|
from scipy.special import legendre
|
||||||
|
|
||||||
from jaxpm.growth import growth_factor, growth_rate
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
'power_spectrum', 'transfer', 'coherence', 'pktranscoh',
|
||||||
|
|
Loading…
Add table
Reference in a new issue