From 137f4e50997bfcd3563b750a43cfda0fdfabb6b9 Mon Sep 17 00:00:00 2001
From: EiffL
Date: Thu, 20 Oct 2022 23:05:39 -0400
Subject: [PATCH] adding distributed ops

---
Review notes:
 * restored line breaks; index hashes below are from the original patch
 * reshape_dense_to_split: split the y axis using shape[1], not shape[2]
 * cic_paint: add received halos to interior cells; handle halo_size=0
 * cic_read: use explicit slice bounds so halo_size=0 gives empty borders
 * lpt: pass the growth factor as the second scalar_multiply argument
 * pm_forces: raise ValueError when neither mesh_shape nor delta_k is given

 jaxpm/distributed_ops.py | 280 +++++++++++++++++++++++++++++++++++++++
 jaxpm/distributed_pm.py  | 106 +++++++++++++++
 2 files changed, 386 insertions(+)
 create mode 100644 jaxpm/distributed_ops.py
 create mode 100644 jaxpm/distributed_pm.py

diff --git a/jaxpm/distributed_ops.py b/jaxpm/distributed_ops.py
new file mode 100644
index 0000000..004c83a
--- /dev/null
+++ b/jaxpm/distributed_ops.py
@@ -0,0 +1,280 @@
+import jax
+import jax.numpy as jnp
+import jax.lax as lax
+from functools import partial
+from jax.experimental.maps import xmap
+from jax.experimental.pjit import pjit, PartitionSpec
+
+import jax_cosmo as jc
+import jaxpm as jpm
+
+# TODO: add a way to configure axis resources from command line
+# Logical xmap axis name -> name of the device-mesh axis it is placed on.
+axis_resources = {'x': 'nx', 'y': 'ny'}
+# Size of each device-mesh axis.
+mesh_size = {'nx': 2, 'ny': 2}
+
+
+@partial(xmap,
+         in_axes=({0: 'x', 2: 'y'},
+                  {0: 'x', 2: 'y'},
+                  {0: 'x', 2: 'y'}),
+         out_axes=({0: 'x', 2: 'y'}),
+         axis_resources=axis_resources)
+def stack3d(a, b, c):
+    """ Stacks three arrays along a new trailing axis.
+
+    Dimensions 0 and 2 of every input and of the output are mapped to the
+    'x' and 'y' named axes.
+    """
+    return jnp.stack([a, b, c], axis=-1)
+
+
+@partial(xmap,
+         in_axes=({0: 'x', 2: 'y'}),
+         out_axes=({0: 'x', 2: 'y'}),
+         axis_resources=axis_resources)
+def scalar_multiply(a, factor):
+    """ Multiplies a distributed array by a factor.
+
+    NOTE(review): in_axes is a single (parenthesised) spec although the
+    function takes two arguments, so `factor` appears to receive the same
+    axis mapping as `a` -- confirm this is intended for scalar factors.
+    """
+    return a * factor
+
+
+@partial(xmap,
+         in_axes=({0: 'x', 2: 'y'},
+                  {0: 'x', 2: 'y'}),
+         out_axes=({0: 'x', 2: 'y'}),
+         axis_resources=axis_resources)
+def add(a, b):
+    """ Adds two distributed arrays element-wise. """
+    return a + b
+
+
+@partial(xmap,
+         in_axes=['x', 'y'],
+         out_axes=['x', 'y'],
+         axis_resources=axis_resources)
+def fft3d(mesh):
+    """ Performs a 3D complex Fourier transform
+
+    Args:
+        mesh: a real 3D tensor of shape [Nx, Ny, Nz]
+
+    Returns:
+        3D FFT of the input, note that the dimensions of the output
+        are transposed.
+    """
+    # Alternate 1D FFTs with all_to_all exchanges over 'x' and then 'y',
+    # so that three 1D transforms are applied in total.
+    mesh = jnp.fft.fft(mesh)
+    mesh = lax.all_to_all(mesh, 'x', 0, 0)
+    mesh = jnp.fft.fft(mesh)
+    mesh = lax.all_to_all(mesh, 'y', 0, 0)
+    return jnp.fft.fft(mesh)  # Note the output is transposed # [z, x, y]
+
+
+@partial(xmap,
+         in_axes=['x', 'y'],
+         out_axes=['x', 'y'],
+         axis_resources=axis_resources)
+def ifft3d(mesh):
+    """ Inverse 3D FFT counterpart of `fft3d`; returns the real part.
+
+    The all_to_all exchanges are applied in the reverse order of `fft3d`
+    ('y' first, then 'x').
+    """
+    mesh = jnp.fft.ifft(mesh)
+    mesh = lax.all_to_all(mesh, 'y', 0, 0)
+    mesh = jnp.fft.ifft(mesh)
+    mesh = lax.all_to_all(mesh, 'x', 0, 0)
+    return jnp.fft.ifft(mesh).real
+
+
+@partial(xmap,
+         in_axes=['x', 'y'],
+         out_axes={0: 'x', 2: 'y'},
+         axis_resources=axis_resources)
+def normal(key, shape):
+    """ Generates a distributed random normal field
+    Args:
+        key: array of random keys with same layout as computational mesh
+        shape: logical shape of array to sample
+    """
+    return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'],
+                                         shape[1]//mesh_size['ny']]+shape[2:])
+
+
+@partial(xmap,
+         in_axes=(['x', 'y', ...],
+                  [['x'], ['y'], ...]),
+         out_axes=['x', 'y', ...],
+         axis_resources=axis_resources)
+@jax.jit
+def scale_by_power_spectrum(kfield, kvec, k, pk):
+    """ Scales a Fourier-space field by a function interpolated at |k|.
+
+    Args:
+        kfield: field to scale
+        kvec: the three wavenumber components (kx, ky, kz)
+        k, pk: sample points and values passed to the interpolation
+    """
+    kx, ky, kz = kvec
+    kk = jnp.sqrt(kx**2 + ky ** 2 + kz**2)
+    return kfield * jc.scipy.interpolate.interp(kk, k, pk)
+
+
+@partial(xmap,
+         in_axes=(['x', 'y', 'z'],
+                  [['x'], ['y'], ['z']]),
+         out_axes=(['x', 'y', 'z']),
+         axis_resources=axis_resources)
+def gradient_laplace_kernel(kfield, kvec):
+    """ Applies a 1/k^2 kernel and a sine-based gradient kernel to kfield.
+
+    Returns a tuple of three fields, built from the ky, kz and kx
+    components in that order.
+    """
+    kx, ky, kz = kvec
+    kk = (kx**2 + ky**2 + kz**2)
+    # The kernel is set to 1 where kk == 0 instead of using 1/kk there.
+    kernel = jnp.where(kk == 0, 1., 1./kk)
+    return (kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
+            kfield * kernel * 1j * 1 / 6.0 *
+            (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
+            kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)))
+
+
+@partial(xmap,
+         in_axes=([...]),
+         out_axes={0: 'x', 2: 'y'},
+         axis_sizes={'x': mesh_size['nx'],
+                     'y': mesh_size['ny']},
+         axis_resources=axis_resources)
+def meshgrid(x, y, z):
+    """ Generates a mesh grid of appropriate size for the
+    computational mesh we have.
+    """
+    return jnp.stack(jnp.meshgrid(x, y, z), axis=-1)
+
+
+@partial(xmap,
+         in_axes=({0: 'x', 2: 'y'}),
+         out_axes=({0: 'x', 2: 'y'}),
+         axis_resources=axis_resources)
+def cic_paint(pos, mesh_shape, halo_size=0):
+    """ Paints particles onto the local slab, with halo exchange.
+
+    Args:
+        pos: particle positions, reshaped internally to [-1, 3]
+        mesh_shape: logical shape of the full mesh
+        halo_size: width of the padding added on each side along x and y,
+            expected not to exceed the local slab extent
+
+    Returns:
+        The local mesh with the halo padding removed.
+    """
+    nx = mesh_shape[0]//mesh_size['nx']
+    ny = mesh_shape[1]//mesh_size['ny']
+    mesh = jnp.zeros([nx+2*halo_size, ny+2*halo_size] + mesh_shape[2:])
+
+    # Paint particles
+    mesh = jpm.cic_paint(mesh, pos.reshape(-1, 3) +
+                         jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
+
+    # Perform halo exchange. Slices use explicit bounds rather than
+    # negative indices so that halo_size=0 yields empty margins, and the
+    # received margins are added to the interior cells they overlap
+    # (the padding itself is cropped at the end).
+    # NOTE(review): the reversed permutation only looks like a neighbour
+    # exchange when the device-mesh axis has size 2 -- confirm.
+    # Halo exchange along x
+    left = lax.pshuffle(mesh[nx+halo_size:],
+                        perm=range(mesh_size['nx'])[::-1],
+                        axis_name='x')
+    right = lax.pshuffle(mesh[:halo_size],
+                         perm=range(mesh_size['nx'])[::-1],
+                         axis_name='x')
+    mesh = mesh.at[halo_size:2*halo_size].add(left)
+    mesh = mesh.at[nx:nx+halo_size].add(right)
+
+    # Halo exchange along y
+    left = lax.pshuffle(mesh[:, ny+halo_size:],
+                        perm=range(mesh_size['ny'])[::-1],
+                        axis_name='y')
+    right = lax.pshuffle(mesh[:, :halo_size],
+                         perm=range(mesh_size['ny'])[::-1],
+                         axis_name='y')
+    mesh = mesh.at[:, halo_size:2*halo_size].add(left)
+    mesh = mesh.at[:, ny:ny+halo_size].add(right)
+
+    # removing halo and returning mesh
+    return mesh[halo_size:nx+halo_size, halo_size:ny+halo_size]
+
+
+@partial(xmap,
+         in_axes=({0: 'x', 2: 'y'},
+                  {0: 'x', 2: 'y'}),
+         out_axes=({0: 'x', 2: 'y'}),
+         axis_resources=axis_resources)
+def cic_read(mesh, pos, halo_size):
+    """ Reads the mesh at particle positions, with halo exchange.
+
+    Args:
+        mesh: local slab of the field
+        pos: particle positions, reshaped internally to [-1, 3]
+        halo_size: width of the border fetched from neighbours in x and y
+    """
+    # Halo exchange to grab neighboring borders. Explicit bounds are used
+    # instead of negative indices so that halo_size=0 yields empty borders.
+    # Exchange along x
+    left = lax.pshuffle(mesh[mesh.shape[0]-halo_size:],
+                        perm=range(mesh_size['nx'])[::-1],
+                        axis_name='x')
+    right = lax.pshuffle(mesh[:halo_size],
+                         perm=range(mesh_size['nx'])[::-1],
+                         axis_name='x')
+    mesh = jnp.concatenate([left, mesh, right], axis=0)
+    # Exchange along y
+    left = lax.pshuffle(mesh[:, mesh.shape[1]-halo_size:],
+                        perm=range(mesh_size['ny'])[::-1],
+                        axis_name='y')
+    right = lax.pshuffle(mesh[:, :halo_size],
+                         perm=range(mesh_size['ny'])[::-1],
+                         axis_name='y')
+    mesh = jnp.concatenate([left, mesh, right], axis=1)
+
+    # Reading field at particles positions
+    res = jpm.painting.cic_read(mesh, pos.reshape(-1, 3) +
+                                jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]))
+
+    return res
+
+
+@partial(pjit,
+         in_axis_resources=PartitionSpec('nx', 'ny'),
+         out_axis_resources=PartitionSpec('nx', None, 'ny', None))
+def reshape_dense_to_split(x):
+    """ Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z]
+    Changes the logical shape of the array, but no shuffling of the
+    data should be necessary
+    """
+    shape = list(x.shape)
+    # The y extent is shape[1]; shape[2:] is the trailing (z) part.
+    return x.reshape([mesh_size['nx'], shape[0]//mesh_size['nx'],
+                      mesh_size['ny'], shape[1]//mesh_size['ny']] + shape[2:])
+
+
+@partial(pjit,
+         in_axis_resources=PartitionSpec('nx', None, 'ny', None),
+         out_axis_resources=PartitionSpec('nx', 'ny'))
+def reshape_split_to_dense(x):
+    """ Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z]
+    Changes the logical shape of the array, but no shuffling of the
+    data should be necessary
+    """
+    shape = list(x.shape)
+    return x.reshape([shape[0]*shape[1], shape[2]*shape[3]] + shape[4:])
diff --git a/jaxpm/distributed_pm.py b/jaxpm/distributed_pm.py
new file mode 100644
index 0000000..c943152
--- /dev/null
+++ b/jaxpm/distributed_pm.py
@@ -0,0 +1,106 @@
+import jax
+from jax.lax import linear_solve_p
+import jax.numpy as jnp
+from jax.experimental.maps import xmap
+from functools import partial
+import jax_cosmo as jc
+
+from jaxpm.kernels import fftk
+import jaxpm.distributed_ops as dops
+from jaxpm.growth import growth_factor, growth_rate, dGfa
+
+
+def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16):
+    """
+    Computes gravitational forces on particles using a PM scheme
+
+    Either `mesh_shape` or `delta_k` must be provided; when `delta_k` is
+    missing, it is obtained by painting `positions` and Fourier
+    transforming the result.
+    """
+    if mesh_shape is None:
+        if delta_k is None:
+            raise ValueError("pm_forces requires mesh_shape or delta_k")
+        mesh_shape = delta_k.shape
+    kvec = [k.squeeze() for k in fftk(mesh_shape)]
+
+    if delta_k is None:
+        delta = dops.cic_paint(positions, mesh_shape, halo_size)
+        delta_k = dops.fft3d(dops.reshape_split_to_dense(delta))
+
+    forces_k = dops.gradient_laplace_kernel(delta_k, kvec)
+
+    # Recovers forces at particle positions
+    forces = [dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)),
+                            positions, halo_size) for f in forces_k]
+
+    return dops.stack3d(*forces)
+
+
+def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True):
+    """
+    Generate initial conditions.
+    Seed should have the dimension of the computational mesh
+    """
+
+    # Sample normal field
+    field = dops.normal(seed, mesh_shape)
+
+    # Go to Fourier space
+    field = dops.fft3d(dops.reshape_split_to_dense(field))
+
+    # Rescaling k to physical units
+    kvec = [k.squeeze() / box_size[i] * mesh_shape[i]
+            for i, k in enumerate(fftk(mesh_shape, symmetric=False))]
+    k = jnp.logspace(-4, 2, 256)
+    pk = jc.power.linear_matter_power(cosmo, k)
+    pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
+               ) / (box_size[0] * box_size[1] * box_size[2])
+
+    field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk))
+
+    if return_Fourier:
+        return field
+    else:
+        return dops.reshape_dense_to_split(dops.ifft3d(field))
+
+
+def lpt(cosmo, initial_conditions, positions, a):
+    """
+    Computes first order LPT displacement
+
+    Returns the tuple (dx, p).
+    """
+    initial_force = pm_forces(positions, delta_k=initial_conditions)
+    a = jnp.atleast_1d(a)
+    # scalar_multiply takes the array and the factor as two arguments.
+    dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a))
+    p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) *
+                             jnp.sqrt(jc.background.Esqr(cosmo, a)))
+    return dx, p
+
+
+def make_ode_fn(mesh_shape):
+    """
+    Builds the `nbody_ode(state, a, cosmo)` function for `mesh_shape`
+    """
+
+    def nbody_ode(state, a, cosmo):
+        """
+        state is a tuple (position, velocities)
+        """
+        pos, vel = state
+
+        forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
+
+        # Computes the update of position (drift)
+        dpos = dops.scalar_multiply(
+            vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
+
+        # Computes the update of velocity (kick)
+        dvel = dops.scalar_multiply(
+            forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))))
+
+        return dpos, dvel
+
+    return nbody_ode