diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index b343fc8..83d5cb9 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -10,12 +10,27 @@ except ImportError: print("jaxdecomp not installed. Distributed functions will not work.") distributed = False +from functools import partial + +import jax import jax.numpy as jnp from jax._src import mesh as mesh_lib from jax.experimental.shard_map import shard_map -from functools import partial from jax.sharding import PartitionSpec as P +# NOTE +# This should not be used as a decorator +# Must be used inside a function only +# Example +# BAD +# @autoshmap +# def foo(): +# pass +# GOOD +# def foo(): +# return autoshmap(foo_impl)() + + def autoshmap(f: Callable, in_specs: Specs, out_specs: Specs, @@ -34,31 +49,43 @@ def fft3d(x): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): return jaxdecomp.pfft3d(x.astype(jnp.complex64)) else: - return jnp.fft.rfftn(x) - + return jnp.fft.fftn(x.astype(jnp.complex64)) + def ifft3d(x): if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): return jaxdecomp.pifft3d(x).real else: - return jnp.fft.irfftn(x) - -def halo_exchange(x): - if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): - return jaxdecomp.halo_exchange(x) + return jnp.fft.ifftn(x).real + + +def get_halo_size(halo_size): + mesh = mesh_lib.thread_resources.env.physical_mesh + if mesh.empty: + zero_ext = (0, 0, 0) + zero_tuple = (0, 0) + return (zero_tuple, zero_tuple, zero_tuple), zero_ext + else: + pdims = mesh.devices.shape + halo_x = (0, 0) if pdims[0] == 1 else (halo_size, halo_size) + halo_y = (0, 0) if pdims[1] == 1 else (halo_size, halo_size) + + halo_x_ext = 0 if pdims[0] == 1 else halo_size // 2 + halo_y_ext = 0 if pdims[1] == 1 else halo_size // 2 + return ((halo_x, halo_y, (0, 0)), (halo_x_ext, halo_y_ext, 0)) + + +def halo_exchange(x, halo_extents, halo_periods=(True, True, True)): + mesh = mesh_lib.thread_resources.env.physical_mesh + if distributed and not (mesh.empty) and (halo_extents[0] > 0 + or halo_extents[1] > 0): + return jaxdecomp.halo_exchange(x, halo_extents, halo_periods) else: return x -@partial(autoshmap, - in_specs=(P('x', 'y'), P()), - out_specs=P('x', 'y')) -def slice_pad_impl(x, pad_width): - return jnp.pad(x, pad_width) -@partial(autoshmap, - in_specs=(P('x', 'y'), P()), - out_specs=P('x', 'y')) def slice_unpad_impl(x, pad_width): + halo_x, _ = pad_width[0] halo_y, _ = pad_width[0] @@ -68,17 +95,28 @@ def slice_unpad_impl(x, pad_width): # Apply corrections along y x = x.at[:, halo_y:halo_y + halo_y // 2].add(x[:, :halo_y // 2]) x = x.at[:, -(halo_y + halo_y // 2):-halo_y].add(x[:, -halo_y // 2:]) - return x + + return x[halo_x:-halo_x, halo_y:-halo_y, :] + def slice_pad(x, pad_width): - if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): - return slice_pad_impl(x, pad_width) + mesh = mesh_lib.thread_resources.env.physical_mesh + if distributed and not (mesh.empty) and (pad_width[0][0] > 0 + or pad_width[1][0] > 0): + return autoshmap((partial(jnp.pad, pad_width=pad_width)), + in_specs=(P('x', 'y')), + out_specs=P('x', 'y'))(x) else: return x - + + def slice_unpad(x, pad_width): - if distributed and not (mesh_lib.thread_resources.env.physical_mesh.empty): - return slice_unpad_impl(x, pad_width) + mesh = mesh_lib.thread_resources.env.physical_mesh + if distributed and not (mesh.empty) and (pad_width[0][0] > 0 + or pad_width[1][0] > 0): + return autoshmap(partial(slice_unpad_impl, pad_width=pad_width), + in_specs=(P('x', 'y')), + out_specs=P('x', 'y'))(x) else: return x diff --git a/jaxpm/pm.py b/jaxpm/pm.py index fe16450..20c251e 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -9,10 +9,10 @@ from jaxpm.distributed import autoshmap, fft3d, get_local_shape, ifft3d from jaxpm.growth import dGfa, growth_factor, growth_rate from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, longrange_kernel) -from jaxpm.painting import cic_paint, cic_read +from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx -def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): +def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): """ Computes gravitational forces on particles using a PM scheme """ @@ -21,7 +21,7 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): kvec = fftk(mesh_shape) if delta is None: - delta_k = fft3d(cic_paint(jnp.zeros(mesh_shape), positions)) + delta_k = fft3d(cic_paint_dx(positions, halo_size=0)) else: delta_k = fft3d(delta) @@ -29,26 +29,28 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0): pot_k = delta_k * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=r_split) # Computes gravitational forces - return jnp.stack([ - cic_read(ifft3d(gradient_kernel(kvec, i) * pot_k), positions) + forces = jnp.stack([ + cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), halo_size=0) for i in range(3) ], - axis=-1) + axis=-1) + return forces -def lpt(cosmo, initial_conditions, a, particles_shape=None): +def lpt(cosmo, initial_conditions, a, halo_size=0): """ Computes first order LPT displacement """ - if particles_shape is None: - particles_shape = initial_conditions.shape - local_mesh_shape = get_local_shape(particles_shape) + local_mesh_shape = get_local_shape(initial_conditions.shape) + (3, ) displacement = autoshmap( - partial(jnp.zeros, shape=local_mesh_shape+[3], dtype='float32'), + partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), in_specs=(), out_specs=P('x', 'y'))() # yapf: disable - initial_force = pm_forces(displacement, delta=initial_conditions) + + initial_force = pm_forces(displacement, + delta=initial_conditions, + halo_size=halo_size) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, @@ -80,7 +82,7 @@ def linear_field(mesh_shape, box_size, pk, seed): return field -def make_ode_fn(mesh_shape): +def make_ode_fn(mesh_shape, halo_size=0): def nbody_ode(state, a, cosmo): """ @@ -88,7 +90,8 @@ def make_ode_fn(mesh_shape): """ pos, vel = state - forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m + forces = pm_forces(pos, mesh_shape=mesh_shape, + halo_size=halo_size) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel