diff --git a/jaxpm/experimental/__init__.py b/jaxpm/experimental/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/jaxpm/experimental/distributed_ops.py b/jaxpm/experimental/distributed_ops.py deleted file mode 100644 index a06b03c..0000000 --- a/jaxpm/experimental/distributed_ops.py +++ /dev/null @@ -1,292 +0,0 @@ -from functools import partial - -import jax -import jax.lax as lax -import jax.numpy as jnp -import jax_cosmo as jc -from jax.experimental.maps import xmap -from jax.experimental.pjit import PartitionSpec, pjit - -import jaxpm.painting as paint - -# TODO: add a way to configure axis resources from command line -axis_resources = {'x': 'nx', 'y': 'ny'} -mesh_size = {'nx': 2, 'ny': 2} - - -@partial(xmap, - in_axes=({ - 0: 'x', - 2: 'y' - }, { - 0: 'x', - 2: 'y' - }, { - 0: 'x', - 2: 'y' - }), - out_axes=({ - 0: 'x', - 2: 'y' - }), - axis_resources=axis_resources) -def stack3d(a, b, c): - return jnp.stack([a, b, c], axis=-1) - - -@partial(xmap, - in_axes=({ - 0: 'x', - 2: 'y' - }, [...]), - out_axes=({ - 0: 'x', - 2: 'y' - }), - axis_resources=axis_resources) -def scalar_multiply(a, factor): - return a * factor - - -@partial(xmap, - in_axes=({ - 0: 'x', - 2: 'y' - }, { - 0: 'x', - 2: 'y' - }), - out_axes=({ - 0: 'x', - 2: 'y' - }), - axis_resources=axis_resources) -def add(a, b): - return a + b - - -@partial(xmap, - in_axes=['x', 'y', ...], - out_axes=['x', 'y', ...], - axis_resources=axis_resources) -def fft3d(mesh): - """ Performs a 3D complex Fourier transform - - Args: - mesh: a real 3D tensor of shape [Nx, Ny, Nz] - - Returns: - 3D FFT of the input, note that the dimensions of the output - are tranposed. - """ - mesh = jnp.fft.fft(mesh) - mesh = lax.all_to_all(mesh, 'x', 0, 0) - mesh = jnp.fft.fft(mesh) - mesh = lax.all_to_all(mesh, 'y', 0, 0) - return jnp.fft.fft(mesh) # Note the output is transposed # [z, x, y] - - -@partial(xmap, - in_axes=['x', 'y', ...], - out_axes=['x', 'y', ...], - axis_resources=axis_resources) -def ifft3d(mesh): - mesh = jnp.fft.ifft(mesh) - mesh = lax.all_to_all(mesh, 'y', 0, 0) - mesh = jnp.fft.ifft(mesh) - mesh = lax.all_to_all(mesh, 'x', 0, 0) - return jnp.fft.ifft(mesh).real - - -def normal(key, shape=[]): - - @partial(xmap, - in_axes=['x', 'y', ...], - out_axes={ - 0: 'x', - 2: 'y' - }, - axis_resources=axis_resources) - def fn(key): - """ Generate a distributed random normal distributions - Args: - key: array of random keys with same layout as computational mesh - shape: logical shape of array to sample - """ - return jax.random.normal( - key, - shape=[shape[0] // mesh_size['nx'], shape[1] // mesh_size['ny']] + - shape[2:]) - - return fn(key) - - -@partial(xmap, - in_axes=(['x', 'y', ...], [['x'], ['y'], [...]], [...], [...]), - out_axes=['x', 'y', ...], - axis_resources=axis_resources) -@jax.jit -def scale_by_power_spectrum(kfield, kvec, k, pk): - kx, ky, kz = kvec - kk = jnp.sqrt(kx**2 + ky**2 + kz**2) - return kfield * jc.scipy.interpolate.interp(kk, k, pk) - - -@partial(xmap, - in_axes=(['x', 'y', 'z'], [['x'], ['y'], ['z']]), - out_axes=(['x', 'y', 'z']), - axis_resources=axis_resources) -def gradient_laplace_kernel(kfield, kvec): - kx, ky, kz = kvec - kk = (kx**2 + ky**2 + kz**2) - kernel = jnp.where(kk == 0, 1., 1. / kk) - return (kfield * kernel * 1j * 1 / 6.0 * - (8 * jnp.sin(ky) - jnp.sin(2 * ky)), kfield * kernel * 1j * 1 / - 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), kfield * kernel * 1j * - 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))) - - -@partial(xmap, - in_axes=([...]), - out_axes={ - 0: 'x', - 2: 'y' - }, - axis_sizes={ - 'x': mesh_size['nx'], - 'y': mesh_size['ny'] - }, - axis_resources=axis_resources) -def meshgrid(x, y, z): - """ Generates a mesh grid of appropriate size for the - computational mesh we have. - """ - return jnp.stack(jnp.meshgrid(x, y, z), axis=-1) - - -def cic_paint(pos, mesh_shape, halo_size=0): - - @partial(xmap, - in_axes=({ - 0: 'x', - 2: 'y' - }), - out_axes=({ - 0: 'x', - 2: 'y' - }), - axis_resources=axis_resources) - def fn(pos): - - mesh = jnp.zeros([ - mesh_shape[0] // mesh_size['nx'] + - 2 * halo_size, mesh_shape[1] // mesh_size['ny'] + 2 * halo_size - ] + mesh_shape[2:]) - - # Paint particles - mesh = paint.cic_paint( - mesh, - pos.reshape(-1, 3) + - jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) - - # Perform halo exchange - # Halo exchange along x - left = lax.pshuffle(mesh[-2 * halo_size:], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - right = lax.pshuffle(mesh[:2 * halo_size], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - mesh = mesh.at[:2 * halo_size].add(left) - mesh = mesh.at[-2 * halo_size:].add(right) - - # Halo exchange along y - left = lax.pshuffle(mesh[:, -2 * halo_size:], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - right = lax.pshuffle(mesh[:, :2 * halo_size], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - mesh = mesh.at[:, :2 * halo_size].add(left) - mesh = mesh.at[:, -2 * halo_size:].add(right) - - # removing halo and returning mesh - return mesh[halo_size:-halo_size, halo_size:-halo_size] - - return fn(pos) - - -def cic_read(mesh, pos, halo_size=0): - - @partial(xmap, - in_axes=( - { - 0: 'x', - 2: 'y' - }, - { - 0: 'x', - 2: 'y' - }, - ), - out_axes=({ - 0: 'x', - 2: 'y' - }), - axis_resources=axis_resources) - def fn(mesh, pos): - - # Halo exchange to grab neighboring borders - # Exchange along x - left = lax.pshuffle(mesh[-halo_size:], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - right = lax.pshuffle(mesh[:halo_size], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - mesh = jnp.concatenate([left, mesh, right], axis=0) - # Exchange along y - left = lax.pshuffle(mesh[:, -halo_size:], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - right = lax.pshuffle(mesh[:, :halo_size], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - mesh = jnp.concatenate([left, mesh, right], axis=1) - - # Reading field at particles positions - res = paint.cic_read( - mesh, - pos.reshape(-1, 3) + - jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) - - return res.reshape(pos.shape[:-1]) - - return fn(mesh, pos) - - -@partial(pjit, - in_axis_resources=PartitionSpec('nx', 'ny'), - out_axis_resources=PartitionSpec('nx', None, 'ny', None)) -def reshape_dense_to_split(x): - """ Redistribute data from [x,y,z] convention to [Nx,x,Ny,y,z] - Changes the logical shape of the array, but no shuffling of the - data should be necessary - """ - shape = list(x.shape) - return x.reshape([ - mesh_size['nx'], shape[0] // - mesh_size['nx'], mesh_size['ny'], shape[2] // mesh_size['ny'] - ] + shape[2:]) - - -@partial(pjit, - in_axis_resources=PartitionSpec('nx', None, 'ny', None), - out_axis_resources=PartitionSpec('nx', 'ny')) -def reshape_split_to_dense(x): - """ Redistribute data from [Nx,x,Ny,y,z] convention to [x,y,z] - Changes the logical shape of the array, but no shuffling of the - data should be necessary - """ - shape = list(x.shape) - return x.reshape([shape[0] * shape[1], shape[2] * shape[3]] + shape[4:]) diff --git a/jaxpm/experimental/distributed_pm.py b/jaxpm/experimental/distributed_pm.py deleted file mode 100644 index b633cf7..0000000 --- a/jaxpm/experimental/distributed_pm.py +++ /dev/null @@ -1,100 +0,0 @@ -from functools import partial - -import jax -import jax.numpy as jnp -import jax_cosmo as jc -from jax.experimental.maps import xmap -from jax.lax import linear_solve_p - -import jaxpm.experimental.distributed_ops as dops -from jaxpm.growth import dGfa, growth_factor, growth_rate -from jaxpm.kernels import fftk - - -def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16): - """ - Computes gravitational forces on particles using a PM scheme - """ - if mesh_shape is None: - mesh_shape = delta_k.shape - kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)] - - if delta_k is None: - delta = dops.cic_paint(positions, mesh_shape, halo_size) - delta_k = dops.fft3d(dops.reshape_split_to_dense(delta)) - - forces_k = dops.gradient_laplace_kernel(delta_k, kvec) - - # Recovers forces at particle positions - forces = [ - dops.cic_read(dops.reshape_dense_to_split(dops.ifft3d(f)), positions, - halo_size) for f in forces_k - ] - - return dops.stack3d(*forces) - - -def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True): - """ - Generate initial conditions. - Seed should have the dimension of the computational mesh - """ - - # Sample normal field - field = dops.normal(seed, shape=mesh_shape) - - # Go to Fourier space - field = dops.fft3d(dops.reshape_split_to_dense(field)) - - # Rescaling k to physical units - kvec = [ - k.squeeze() / box_size[i] * mesh_shape[i] - for i, k in enumerate(fftk(mesh_shape, symmetric=False)) - ] - k = jnp.logspace(-4, 2, 256) - pk = jc.power.linear_matter_power(cosmo, k) - pk = pk * (mesh_shape[0] * mesh_shape[1] * - mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) - - field = dops.scale_by_power_spectrum(field, kvec, k, jnp.sqrt(pk)) - - if return_Fourier: - return field - else: - return dops.reshape_dense_to_split(dops.ifft3d(field)) - - -def lpt(cosmo, initial_conditions, positions, a): - """ - Computes first order LPT displacement - """ - initial_force = pm_forces(positions, delta_k=initial_conditions) - a = jnp.atleast_1d(a) - dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a)) - p = dops.scalar_multiply( - dx, - a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a))) - return dx, p - - -def make_ode_fn(mesh_shape): - - def nbody_ode(state, a, cosmo): - """ - state is a tuple (position, velocities) - """ - pos, vel = state - - forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m - - # Computes the update of position (drift) - dpos = dops.scalar_multiply( - vel, 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a)))) - - # Computes the update of velocity (kick) - dvel = dops.scalar_multiply( - forces, 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)))) - - return dpos, dvel - - return nbody_ode