From ebab27b4c3ca2a51b70ce4ae6499a8c3779eeeca Mon Sep 17 00:00:00 2001 From: EiffL Date: Thu, 20 Oct 2022 21:03:15 -0700 Subject: [PATCH] update data --- jaxpm/distributed_ops.py | 166 ++++++++++++++++++++------------------- jaxpm/distributed_pm.py | 7 +- 2 files changed, 91 insertions(+), 82 deletions(-) diff --git a/jaxpm/distributed_ops.py b/jaxpm/distributed_ops.py index 004c83a..9adcfc9 100644 --- a/jaxpm/distributed_ops.py +++ b/jaxpm/distributed_ops.py @@ -6,7 +6,7 @@ from jax.experimental.maps import xmap from jax.experimental.pjit import pjit, PartitionSpec import jax_cosmo as jc -import jaxpm as jpm +import jaxpm.painting as paint # TODO: add a way to configure axis resources from command line axis_resources = {'x': 'nx', 'y': 'ny'} @@ -24,7 +24,7 @@ def stack3d(a, b, c): @partial(xmap, - in_axes=({0: 'x', 2: 'y'}), + in_axes=({0: 'x', 2: 'y'},[...]), out_axes=({0: 'x', 2: 'y'}), axis_resources=axis_resources) def scalar_multiply(a, factor): @@ -41,8 +41,8 @@ def add(a, b): @partial(xmap, - in_axes=['x', 'y'], - out_axes=['x', 'y'], + in_axes=['x', 'y',...], + out_axes=['x', 'y',...], axis_resources=axis_resources) def fft3d(mesh): """ Performs a 3D complex Fourier transform @@ -62,8 +62,8 @@ def fft3d(mesh): @partial(xmap, - in_axes=['x', 'y'], - out_axes=['x', 'y'], + in_axes=['x', 'y',...], + out_axes=['x', 'y',...], axis_resources=axis_resources) def ifft3d(mesh): mesh = jnp.fft.ifft(mesh) @@ -72,24 +72,25 @@ def ifft3d(mesh): mesh = lax.all_to_all(mesh, 'x', 0, 0) return jnp.fft.ifft(mesh).real - -@partial(xmap, - in_axes=['x', 'y'], - out_axes={0: 'x', 2: 'y'}, - axis_resources=axis_resources) -def normal(key, shape): - """ Generate a distributed random normal distributions - Args: - key: array of random keys with same layout as computational mesh - shape: logical shape of array to sample - """ - return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'], - shape[1]//mesh_size['ny']]+shape[2:]) +def normal(key, shape=[]): + @partial(xmap, + in_axes=['x', 'y',...], + out_axes={0: 'x', 2: 'y'}, + axis_resources=axis_resources) + def fn(key): + """ Generate a distributed random normal distributions + Args: + key: array of random keys with same layout as computational mesh + shape: logical shape of array to sample + """ + return jax.random.normal(key, shape=[shape[0]//mesh_size['nx'], + shape[1]//mesh_size['ny']]+shape[2:]) + return fn(key) @partial(xmap, in_axes=(['x', 'y', ...], - [['x'], ['y'], ...]), + [['x'], ['y'], [...]], [...], [...]), out_axes=['x', 'y', ...], axis_resources=axis_resources) @jax.jit @@ -124,78 +125,85 @@ def meshgrid(x, y, z): """ Generates a mesh grid of appropriate size for the computational mesh we have. """ - return jnp.stack(jnp.meshgrid(x, y, z), axis=-1) + return jnp.stack(jnp.meshgrid(x, + y, + z), axis=-1) -@partial(xmap, - in_axes=({0: 'x', 2: 'y'}), - out_axes=({0: 'x', 2: 'y'}), - axis_resources=axis_resources) def cic_paint(pos, mesh_shape, halo_size=0): + @partial(xmap, + in_axes=({0: 'x', 2: 'y'}), + out_axes=({0: 'x', 2: 'y'}), + axis_resources=axis_resources) + def fn(pos): - mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size, - mesh_shape[1]//mesh_size['ny']+2*halo_size] - + mesh_shape[2:]) + mesh = jnp.zeros([mesh_shape[0]//mesh_size['nx']+2*halo_size, + mesh_shape[1]//mesh_size['ny']+2*halo_size] + + mesh_shape[2:]) - # Paint particles - mesh = jpm.cic_paint(mesh, pos.reshape(-1, 3) + - jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) + # Paint particles + mesh = paint.cic_paint(mesh, pos.reshape(-1, 3) + + jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) - # Perform halo exchange - # Halo exchange along x - left = lax.pshuffle(mesh[-halo_size:], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - right = lax.pshuffle(mesh[:halo_size], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - mesh = mesh.at[:halo_size].add(left) - mesh = mesh.at[-halo_size:].add(right) + # Perform halo exchange + # Halo exchange along x + left = lax.pshuffle(mesh[-halo_size:], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + right = lax.pshuffle(mesh[:halo_size], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + mesh = mesh.at[:halo_size].add(left) + mesh = mesh.at[-halo_size:].add(right) - # Halo exchange along y - left = lax.pshuffle(mesh[:, -halo_size:], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - right = lax.pshuffle(mesh[:, :halo_size], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - mesh = mesh.at[:, :halo_size].add(left) - mesh = mesh.at[:, -halo_size:].add(right) + # Halo exchange along y + left = lax.pshuffle(mesh[:, -halo_size:], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + right = lax.pshuffle(mesh[:, :halo_size], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + mesh = mesh.at[:, :halo_size].add(left) + mesh = mesh.at[:, -halo_size:].add(right) - # removing halo and returning mesh - return mesh[halo_size:-halo_size, halo_size:-halo_size] + # removing halo and returning mesh + return mesh[halo_size:-halo_size, halo_size:-halo_size] + return fn(pos) -@partial(xmap, - in_axes=({0: 'x', 2: 'y'}, - {0: 'x', 2: 'y'}), - out_axes=({0: 'x', 2: 'y'}), - axis_resources=axis_resources) -def cic_read(mesh, pos, halo_size): +def cic_read(mesh, pos, halo_size=0): + @partial(xmap, + in_axes=({0: 'x', 2: 'y'}, + {0: 'x', 2: 'y'},), + out_axes=({0: 'x', 2: 'y'}), + axis_resources=axis_resources) + def fn(mesh, pos): - # Halo exchange to grab neighboring borders - # Exchange along x - left = lax.pshuffle(mesh[-halo_size:], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - right = lax.pshuffle(mesh[:halo_size], - perm=range(mesh_size['nx'])[::-1], - axis_name='x') - mesh = jnp.concatenate([left, mesh, right], axis=0) - # Exchange along y - left = lax.pshuffle(mesh[:, -halo_size:], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - right = lax.pshuffle(mesh[:, :halo_size], - perm=range(mesh_size['ny'])[::-1], - axis_name='y') - mesh = jnp.concatenate([left, mesh, right], axis=1) + # Halo exchange to grab neighboring borders + # Exchange along x + left = lax.pshuffle(mesh[-halo_size:], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + right = lax.pshuffle(mesh[:halo_size], + perm=range(mesh_size['nx'])[::-1], + axis_name='x') + mesh = jnp.concatenate([left, mesh, right], axis=0) + # Exchange along y + left = lax.pshuffle(mesh[:, -halo_size:], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + right = lax.pshuffle(mesh[:, :halo_size], + perm=range(mesh_size['ny'])[::-1], + axis_name='y') + mesh = jnp.concatenate([left, mesh, right], axis=1) - # Reading field at particles positions - res = jpm.painting.cic_read(mesh, pos.reshape(-1, 3) + - jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) + # Reading field at particles positions + res = paint.cic_read(mesh, pos.reshape(-1, 3) + + jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])) - return res + return res.reshape(pos.shape[:-1]) + + return fn(mesh, pos) @partial(pjit, diff --git a/jaxpm/distributed_pm.py b/jaxpm/distributed_pm.py index c943152..7944fd2 100644 --- a/jaxpm/distributed_pm.py +++ b/jaxpm/distributed_pm.py @@ -16,7 +16,7 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=16): """ if mesh_shape is None: mesh_shape = delta_k.shape - kvec = [k.squeeze() for k in fftk(mesh_shape)] + kvec = [k.squeeze() for k in fftk(mesh_shape, symmetric=False)] if delta_k is None: delta = dops.cic_paint(positions, mesh_shape, halo_size) @@ -38,7 +38,7 @@ def linear_field(cosmo, mesh_shape, box_size, seed, return_Fourier=True): """ # Sample normal field - field = dops.normal(seed, mesh_shape) + field = dops.normal(seed, shape=mesh_shape) # Go to Fourier space field = dops.fft3d(dops.reshape_split_to_dense(field)) @@ -64,8 +64,9 @@ def lpt(cosmo, initial_conditions, positions, a): Computes first order LPT displacement """ initial_force = pm_forces(positions, delta_k=initial_conditions) + print(initial_force.shape) a = jnp.atleast_1d(a) - dx = dops.scalar_multiply(initial_force * growth_factor(cosmo, a)) + dx = dops.scalar_multiply(initial_force, growth_factor(cosmo, a)) p = dops.scalar_multiply(dx, a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a))) return dx, p