mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 12:20:54 +00:00
Use new cic_paint with halo
This commit is contained in:
parent
5775a37550
commit
7501b5bc6d
2 changed files with 77 additions and 36 deletions
|
@ -10,12 +10,27 @@ except ImportError:
|
|||
print("jaxdecomp not installed. Distributed functions will not work.")
|
||||
distributed = False
|
||||
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax.experimental.shard_map import shard_map
|
||||
from functools import partial
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
# NOTE
|
||||
# This should not be used as a decorator
|
||||
# Must be used inside a function only
|
||||
# Example
|
||||
# BAD
|
||||
# @autoshmap
|
||||
# def foo():
|
||||
# pass
|
||||
# GOOD
|
||||
# def foo():
|
||||
# return autoshmap(foo_impl)()
|
||||
|
||||
|
||||
def autoshmap(f: Callable,
|
||||
in_specs: Specs,
|
||||
out_specs: Specs,
|
||||
|
@ -34,31 +49,43 @@ def fft3d(x):
|
|||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pfft3d(x.astype(jnp.complex64))
|
||||
else:
|
||||
return jnp.fft.rfftn(x)
|
||||
|
||||
return jnp.fft.fftn(x.astype(jnp.complex64))
|
||||
|
||||
|
||||
def ifft3d(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.pifft3d(x).real
|
||||
else:
|
||||
return jnp.fft.irfftn(x)
|
||||
|
||||
def halo_exchange(x):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return jaxdecomp.halo_exchange(x)
|
||||
return jnp.fft.ifftn(x).real
|
||||
|
||||
|
||||
def get_halo_size(halo_size):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if mesh.empty:
|
||||
zero_ext = (0, 0, 0)
|
||||
zero_tuple = (0, 0)
|
||||
return (zero_tuple, zero_tuple, zero_tuple), zero_ext
|
||||
else:
|
||||
pdims = mesh.devices.shape
|
||||
halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size)
|
||||
halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size)
|
||||
|
||||
halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2
|
||||
halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2
|
||||
return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0))
|
||||
|
||||
|
||||
def halo_exchange(x, halo_extents, halo_periods=(True, True, True)):
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (halo_extents[0] > 0
|
||||
or halo_extents[1] > 0):
|
||||
return jaxdecomp.halo_exchange(x, halo_extents, halo_periods)
|
||||
else:
|
||||
return x
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))
|
||||
def slice_pad_impl(x, pad_width):
|
||||
return jnp.pad(x, pad_width)
|
||||
|
||||
@partial(autoshmap,
|
||||
in_specs=(P('x', 'y'), P()),
|
||||
out_specs=P('x', 'y'))
|
||||
def slice_unpad_impl(x, pad_width):
|
||||
|
||||
halo_x, _ = pad_width[0]
|
||||
halo_y, _ = pad_width[0]
|
||||
|
||||
|
@ -68,17 +95,28 @@ def slice_unpad_impl(x, pad_width):
|
|||
# Apply corrections along y
|
||||
x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2])
|
||||
x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:])
|
||||
return x
|
||||
|
||||
return x[halo_x:-halo_x, halo_y:-halo_y, :]
|
||||
|
||||
|
||||
def slice_pad(x, pad_width):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return slice_pad_impl(x, pad_width)
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (pad_width[0][0] > 0
|
||||
or pad_width[1][0] > 0):
|
||||
return autoshmap((partial(jnp.pad, pad_width=pad_width)),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
|
||||
def slice_unpad(x, pad_width):
|
||||
if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty):
|
||||
return slice_unpad_impl(x, pad_width)
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
if distributed and not (mesh.empty) and (pad_width[0][0] > 0
|
||||
or pad_width[1][0] > 0):
|
||||
return autoshmap(partial(slice_unpad_impl, pad_width=pad_width),
|
||||
in_specs=(P('x', 'y')),
|
||||
out_specs=P('x', 'y'))(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
|
|
31
jaxpm/pm.py
31
jaxpm/pm.py
|
@ -9,10 +9,10 @@ from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d
|
|||
from jaxpm.growth import dGfa, growth_factor, growth_rate
|
||||
from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel,
|
||||
longrange_kernel)
|
||||
from jaxpm.painting import cic_paint, cic_read
|
||||
from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
||||
def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
|
@ -21,7 +21,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
|||
kvec = fftk(mesh_shape)
|
||||
|
||||
if delta is None:
|
||||
delta_k = fft3d(cic_paint(jnp.zeros(mesh_shape), positions))
|
||||
delta_k = fft3d(cic_paint_dx(positions, halo_size=0))
|
||||
else:
|
||||
delta_k = fft3d(delta)
|
||||
|
||||
|
@ -29,26 +29,28 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0):
|
|||
pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec,
|
||||
r_split=r_split)
|
||||
# Computes gravitational forces
|
||||
return jnp.stack([
|
||||
cic_read(ifft3d(gradient_kernel(kvec, i) * pot_k), positions)
|
||||
forces = jnp.stack([
|
||||
cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=0)
|
||||
for i in range(3)
|
||||
],
|
||||
axis=-1)
|
||||
axis=-1)
|
||||
|
||||
return forces
|
||||
|
||||
def lpt(cosmo, initial_conditions, a, particles_shape=None):
|
||||
def lpt(cosmo, initial_conditions, a, halo_size=0):
|
||||
"""
|
||||
Computes first order LPT displacement
|
||||
"""
|
||||
if particles_shape is None:
|
||||
particles_shape = initial_conditions.shape
|
||||
local_mesh_shape = get_local_shape(particles_shape)
|
||||
local_mesh_shape = get_local_shape(initial_conditions.shape) + (3, )
|
||||
displacement = autoshmap(
|
||||
partial(jnp.zeros, shape=local_mesh_shape+[3], dtype='float32'),
|
||||
partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'),
|
||||
in_specs=(),
|
||||
out_specs=P('x', 'y'))() # yapf: disable
|
||||
|
||||
initial_force = pm_forces(displacement, delta=initial_conditions)
|
||||
|
||||
initial_force = pm_forces(displacement,
|
||||
delta=initial_conditions,
|
||||
halo_size=halo_size)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo,
|
||||
|
@ -80,7 +82,7 @@ def linear_field(mesh_shape, box_size, pk, seed):
|
|||
return field
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape):
|
||||
def make_ode_fn(mesh_shape, halo_size=0):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
"""
|
||||
|
@ -88,7 +90,8 @@ def make_ode_fn(mesh_shape):
|
|||
"""
|
||||
pos, vel = state
|
||||
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape,
|
||||
halo_size=halo_size) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
|
Loading…
Add table
Reference in a new issue