diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 090bd2b..f78a765 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -14,17 +14,23 @@ from jaxpm.kernels import (PGD_kernel, fftk, gradient_kernel, laplace_kernel, from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx -def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): +def pm_forces(positions, + mesh_shape=None, + delta=None, + r_split=0, + halo_size=0, + sharding=None): """ Computes gravitational forces on particles using a PM scheme """ if mesh_shape is None: - assert (delta is not None - ), "If mesh_shape is not provided, delta should be provided" + assert (delta is not None),\ + "If mesh_shape is not provided, delta should be provided" mesh_shape = delta.shape if delta is None: - delta_k = fft3d(cic_paint_dx(positions, halo_size=halo_size)) + delta_k = fft3d( + cic_paint_dx(positions, halo_size=halo_size, sharding=sharding)) else: delta_k = fft3d(delta) @@ -35,7 +41,8 @@ def pm_forces(positions, mesh_shape=None, delta=None, r_split=0, halo_size=0): # Computes gravitational forces forces = jnp.stack([ cic_read_dx(ifft3d(gradient_kernel(kvec, i) * pot_k), - halo_size=halo_size) for i in range(3) + halo_size=halo_size, + sharding=sharding) for i in range(3) ], axis=-1) @@ -77,20 +84,25 @@ def lpt2_source(mesh_size, initial_conditions): return source -def lpt(cosmo, initial_conditions, a, halo_size=0): +def lpt(cosmo, initial_conditions, a, halo_size=0, sharding=None): """ Computes first order LPT displacement """ - local_mesh_shape = (*get_local_shape(initial_conditions.shape), 3) + gpu_mesh = sharding.mesh if sharding is not None else None + spec = sharding.spec if sharding is not None else P() + local_mesh_shape = (*get_local_shape(initial_conditions.shape, sharding), + 3) displacement = autoshmap( partial(jnp.zeros, shape=(local_mesh_shape), dtype='float32'), + gpu_mesh=gpu_mesh, in_specs=(), - out_specs=P('x', 'y'))() # yapf: disable + out_specs=spec)() # yapf: disable initial_force = pm_forces(displacement, delta=initial_conditions, - halo_size=halo_size) + halo_size=halo_size, + sharding=sharding) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, @@ -133,12 +145,12 @@ def lpt2(cosmo, initial_conditions, dx, p, f, a, halo_size=0): return dx, p, f -def linear_field(mesh_shape, box_size, pk, seed): +def linear_field(mesh_shape, box_size, pk, seed, sharding=None): """ Generate initial conditions. """ # Initialize a random field with one slice on each gpu - field = normal_field(mesh_shape, seed=seed) + field = normal_field(mesh_shape, seed=seed, sharding=sharding) field = fft3d(field) kvec = fftk(field) kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 @@ -151,7 +163,7 @@ def linear_field(mesh_shape, box_size, pk, seed): return field -def make_ode_fn(mesh_shape, halo_size=0): +def make_ode_fn(mesh_shape, halo_size=0, sharding=None): def nbody_ode(state, a, cosmo): """ @@ -159,8 +171,9 @@ def make_ode_fn(mesh_shape, halo_size=0): """ pos, vel = state - forces = pm_forces(pos, mesh_shape=mesh_shape, - halo_size=halo_size) * 1.5 * cosmo.Omega_m + forces = pm_forces( + pos, mesh_shape=mesh_shape, halo_size=halo_size, + sharding=sharding) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel