From 72ae0fd88f8a299f6d2a627ecd1e3266267e2f7e Mon Sep 17 00:00:00 2001 From: EiffL Date: Sat, 22 Oct 2022 15:58:32 -0400 Subject: [PATCH] fixed a whole lot of issues --- jaxpm/kernels.py | 137 ++++++++++++++++++++++-------------------- jaxpm/ops.py | 52 +++++++++++----- jaxpm/painting.py | 19 +++--- jaxpm/pm.py | 120 +++++++++++++++++------------------- scripts/test_nbody.py | 78 ++++++++++++++++++++++++ 5 files changed, 251 insertions(+), 155 deletions(-) create mode 100644 scripts/test_nbody.py diff --git a/jaxpm/kernels.py b/jaxpm/kernels.py index 61b9e58..ddd909a 100644 --- a/jaxpm/kernels.py +++ b/jaxpm/kernels.py @@ -1,84 +1,91 @@ +import jax +from jax.experimental.maps import xmap import numpy as np import jax.numpy as jnp +from functools import partial -def fftk(shape, symmetric=True, dtype=np.float32, comms=None): - """ Return k_vector given a shape (nc, nc, nc) - """ - k = [] - if comms is not None: - nx = comms[0].Get_size() - ix = comms[0].Get_rank() - ny = comms[1].Get_size() - iy = comms[1].Get_rank() - shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:]) +def fftk(shape, symmetric=False, dtype=np.float32, comms=None): + """ Return k_vector given a shape (nc, nc, nc) + """ + k = [] - for d in range(len(shape)): - kd = np.fft.fftfreq(shape[d]) - kd *= 2 * np.pi + if comms is not None: + nx = comms[0].Get_size() + ix = comms[0].Get_rank() + ny = comms[1].Get_size() + iy = comms[1].Get_rank() + shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:]) - if symmetric and d == len(shape) - 1: - kd = kd[:shape[d] // 2 + 1] + for d in range(len(shape)): + kd = np.fft.fftfreq(shape[d]) + kd *= 2 * np.pi - if (comms is not None) and d==0: - kd = kd.reshape([nx, -1])[ix] + if symmetric and d == len(shape) - 1: + kd = kd[:shape[d] // 2 + 1] - if (comms is not None) and d==1: - kd = kd.reshape([ny, -1])[iy] + if (comms is not None) and d == 0: + kd = kd.reshape([nx, -1])[ix] - k.append(kd.astype(dtype)) - return k + if (comms is not None) and d == 1: + kd = kd.reshape([ny, -1])[iy] -@partial(jax.pmap, - in_axes=[['x','y','z'], - ['x'],['y'],['z']], - out_axes=['x','y','z',...]) + k.append(kd.astype(dtype)) + return k + + +@partial(xmap, + in_axes=[['x', 'y', ...], + [['x'], ['y'], [...]]], + out_axes=['x', 'y', ...]) def apply_gradient_laplace(kfield, kvec): kx, ky, kz = kvec kk = (kx**2 + ky**2 + kz**2) kernel = jnp.where(kk == 0, 1., 1./kk) return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), - kfield * kernel * 1j * 1 / 6.0 * - (8 * jnp.sin(kz) - jnp.sin(2 * kz)), - kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))],axis=-1) + kfield * kernel * 1j * 1 / 6.0 * + (8 * jnp.sin(kz) - jnp.sin(2 * kz)), + kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1) + def cic_compensation(kvec): - """ - Computes cic compensation kernel. - Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 - Itself based on equation 18 (with p=2) of - `Jing et al 2005 `_ - Args: - kvec: array of k values in Fourier space - Returns: - v: array of kernel - """ - kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] - wts = (kwts[0] * kwts[1] * kwts[2])**(-2) - return wts + """ + Computes cic compensation kernel. + Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 + Itself based on equation 18 (with p=2) of + `Jing et al 2005 `_ + Args: + kvec: array of k values in Fourier space + Returns: + v: array of kernel + """ + kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] + wts = (kwts[0] * kwts[1] * kwts[2])**(-2) + return wts + def PGD_kernel(kvec, kl, ks): - """ - Computes the PGD kernel - Parameters: - ----------- - kvec: array - Array of k values in Fourier space - kl: float - initial long range scale parameter - ks: float - initial dhort range scale parameter - Returns: - -------- - v: array - kernel - """ - kk = sum(ki**2 for ki in kvec) - kl2 = kl**2 - ks4 = ks**4 - mask = (kk == 0).nonzero() - kk[mask] = 1 - v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) - imask = (~(kk == 0)).astype(int) - v *= imask - return v \ No newline at end of file + """ + Computes the PGD kernel + Parameters: + ----------- + kvec: array + Array of k values in Fourier space + kl: float + initial long range scale parameter + ks: float + initial dhort range scale parameter + Returns: + -------- + v: array + kernel + """ + kk = sum(ki**2 for ki in kvec) + kl2 = kl**2 + ks4 = ks**4 + mask = (kk == 0).nonzero() + kk[mask] = 1 + v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) + imask = (~(kk == 0)).astype(int) + v *= imask + return v diff --git a/jaxpm/ops.py b/jaxpm/ops.py index fe6e387..66c9cec 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -73,34 +73,58 @@ def ifft3d(arr, comms=None): def halo_reduce(arr, halo_size, comms=None): + if halo_size <= 0: + return arr # Perform halo exchange along x rank_x = comms[0].Get_rank() + size_x = comms[0].Get_size() margin = arr[-2*halo_size:] - margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1, - comm=comms[0]) - arr = arr.at[:2*halo_size].add(margin) - + left, token = mpi4jax.sendrecv(margin, margin, + (rank_x-1) % size_x, + (rank_x+1) % size_x, + comm=comms[0]) margin = arr[:2*halo_size] - margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1, - comm=comms[0], token=token) - arr = arr.at[-2*halo_size:].add(margin) + right, token = mpi4jax.sendrecv(margin, margin, + (rank_x+1) % size_x, + (rank_x-1) % size_x, + comm=comms[0], token=token) + + arr = arr.at[:2*halo_size].add(left) + arr = arr.at[-2*halo_size:].add(right) # Perform halo exchange along y rank_y = comms[1].Get_rank() + size_y = comms[1].Get_size() margin = arr[:, -2*halo_size:] - margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1, - comm=comms[1], token=token) - arr = arr.at[:, :2*halo_size].add(margin) - + left, token = mpi4jax.sendrecv(margin, margin, + (rank_y-1) % size_y, + (rank_y+1) % size_y, + comm=comms[1], token=token) margin = arr[:, :2*halo_size] - margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1, - comm=comms[1], token=token) - arr = arr.at[:, -2*halo_size:].add(margin) + right, token = mpi4jax.sendrecv(margin, margin, + (rank_y+1) % size_y, + (rank_y-1) % size_y, + comm=comms[1], token=token) + arr = arr.at[:, :2*halo_size].add(left) + arr = arr.at[:, -2*halo_size:].add(right) return arr +def meshgrid3d(shape, comms=None): + if comms is not None: + nx = comms[0].Get_size() + ny = comms[1].Get_size() + + coords = [jnp.arange(shape[0]//nx), + jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]] + else: + coords = [jnp.arange(s) for s in shape[2:]] + + return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) + + def zeros(shape, comms=None): """ Initialize an array of given global shape partitionned if need be accross dimensions. diff --git a/jaxpm/painting.py b/jaxpm/painting.py index c9b72ca..78a3bd9 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -6,7 +6,7 @@ from jaxpm.ops import halo_reduce from jaxpm.kernels import fftk, cic_compensation -def cic_paint(mesh, positions, halo_size=0, token=None, comms=None): +def cic_paint(mesh, positions, halo_size=0, comms=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] @@ -43,11 +43,11 @@ def cic_paint(mesh, positions, halo_size=0, token=None, comms=None): if comms == None: return mesh else: - mesh, token = halo_reduce(mesh, halo_size, token, comms) + mesh = halo_reduce(mesh, halo_size, comms) return mesh[halo_size:-halo_size, halo_size:-halo_size] -def cic_read(mesh, positions, halo_size=0, token=None, comms=None): +def cic_read(mesh, positions, halo_size=0, comms=None): """ Paints positions onto mesh mesh: [nx, ny, nz] positions: [npart, 3] @@ -59,7 +59,7 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None): mesh = jnp.pad(mesh, [[halo_size, halo_size], [halo_size, halo_size], [0, 0]]) - mesh, token = halo_reduce(mesh, halo_size, token, comms) + mesh = halo_reduce(mesh, halo_size, comms) positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) positions = jnp.expand_dims(positions, 1) @@ -75,14 +75,9 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None): neighboor_coords = jnp.mod( neighboor_coords.astype('int32'), jnp.array(mesh.shape)) - res = (mesh[neighboor_coords[..., 0], - neighboor_coords[..., 1], - neighboor_coords[..., 3]]*kernel).sum(axis=-1) - - if comms is not None: - return res - else: - return res, token + return (mesh[neighboor_coords[..., 0], + neighboor_coords[..., 1], + neighboor_coords[..., 3]]*kernel).sum(axis=-1) def cic_paint_2d(mesh, positions, weight): diff --git a/jaxpm/pm.py b/jaxpm/pm.py index 3fcc760..8ec910a 100644 --- a/jaxpm/pm.py +++ b/jaxpm/pm.py @@ -1,107 +1,99 @@ import jax +from jax.experimental.maps import xmap import jax.numpy as jnp import jax_cosmo as jc -from jaxpm.ops import fft3d, ifft3d, zeros +from jaxpm.ops import fft3d, ifft3d, zeros, normal from jaxpm.kernels import fftk, apply_gradient_laplace from jaxpm.painting import cic_paint, cic_read from jaxpm.growth import growth_factor, growth_rate, dGfa + def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None): """ Computes gravitational forces on particles using a PM scheme """ - if mesh_shape is None: - mesh_shape = delta_k.shape - - kvec = fftk(mesh_shape, comms=comms) - if delta_k is None: - delta, token = cic_paint(zeros(mesh_shape,comms=comms), - positions, - halo_size=halo_size, token=token, comms=comms) - delta_k, token = fft3d(delta, token=token, comms=comms) - - # Computes gravitational potential - forces_k = apply_gradient_laplace(kfield, kvec) + delta = cic_paint(zeros(mesh_shape, comms=comms), + positions, + halo_size=halo_size, comms=comms) + delta_k = fft3d(delta, comms=comms) # Computes gravitational forces - fx, token = ifft3d(forces_k[...,0], token=token, comms=comms) - fx, token = cic_read(fx, positions, halo_size=halo_size, comms=comms) + kvec = fftk(delta_k.shape, symmetric=False, comms=comms) + forces_k = apply_gradient_laplace(delta_k, kvec) - fy, token = ifft3d(forces_k[...,1], token=token, comms=comms) - fy, token = cic_read(fy, positions, halo_size=halo_size, comms=comms) + # Interpolate forces at the position of particles + return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real, + positions, halo_size=halo_size, comms=comms) + for i in range(3)], axis=-1) - fz, token = ifft3d(forces_k[...,2], token=token, comms=comms) - fz, token = cic_read(fz, positions, halo_size=halo_size, comms=comms) - return jnp.stack([fx,fy,fz],axis=-1), token - -def lpt(cosmo, initial_conditions, positions, a, token=token, comms=comms): +def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None): """ Computes first order LPT displacement """ - initial_force = pm_forces(positions, delta=initial_conditions, token=token, comms=comms) + initial_force = pm_forces( + positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms) a = jnp.atleast_1d(a) dx = growth_factor(cosmo, a) * initial_force - p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx - f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force - return dx, p, f, comms + p = a**2 * growth_rate(cosmo, a) * \ + jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx + f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \ + dGfa(cosmo, a) * initial_force + return dx, p, f -def linear_field(mesh_shape, box_size, pk, seed): + +def linear_field(cosmo, mesh_shape, box_size, key, comms=None): """ - Generate initial conditions. + Generate initial conditions in Fourier space. """ - kvec = fftk(mesh_shape) - kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 - pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2]) + # Sample normal field + field = normal(key, mesh_shape, comms=comms) - field = jax.random.normal(seed, mesh_shape) - field = jnp.fft.rfftn(field) * pkmesh**0.5 - field = jnp.fft.irfftn(field) - return field + # Transform to Fourier space + kfield = fft3d(field, comms=comms) + + # Rescaling k to physical units + kvec = [k / box_size[i] * mesh_shape[i] + for i, k in enumerate(fftk(kfield.shape, + symmetric=False, + comms=comms))] + + # Evaluating linear matter powerspectrum + k = jnp.logspace(-4, 2, 256) + pk = jc.power.linear_matter_power(cosmo, k) + pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2] + ) / (box_size[0] * box_size[1] * box_size[2]) + + # Multipliyng the field by the proper power spectrum + kfield = xmap(lambda kfield, kx, ky, kz: + kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2), + k, jnp.sqrt(pk)), + in_axes=(('x', 'y', ...), ['x'], ['y'], [...]), + out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2]) + + return kfield + + +def make_ode_fn(mesh_shape, halo_size=0, comms=None): -def make_ode_fn(mesh_shape): - def nbody_ode(state, a, cosmo): """ state is a tuple (position, velocities) """ pos, vel = state - forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m + forces = pm_forces(pos, mesh_shape=mesh_shape, + halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m # Computes the update of position (drift) dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel - + # Computes the update of velocity (kick) dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces - + return dpos, dvel return nbody_ode - - -def pgd_correction(pos, params): - """ - improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671 - args: - pos: particle positions [npart, 3] - params: [alpha, kl, ks] pgd parameters - """ - kvec = fftk(mesh_shape) - - delta = cic_paint(jnp.zeros(mesh_shape), pos) - alpha, kl, ks = params - delta_k = jnp.fft.rfftn(delta) - PGD_range=PGD_kernel(kvec, kl, ks) - - pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range - - forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos) - for i in range(3)],axis=-1) - - dpos_pgd = forces_pgd*alpha - - return dpos_pgd \ No newline at end of file diff --git a/scripts/test_nbody.py b/scripts/test_nbody.py new file mode 100644 index 0000000..b9860f2 --- /dev/null +++ b/scripts/test_nbody.py @@ -0,0 +1,78 @@ +from dataclasses import fields +from mpi4py import MPI +import jax +import jax.numpy as jnp +import numpy as onp +import mpi4jax +from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros +from jaxpm.pm import linear_field, lpt, make_ode_fn +from jaxpm.painting import cic_paint +from jax.experimental.ode import odeint +import jax_cosmo as jc + + +### Setting up a whole bunch of things ####### +# Create communicators +world = MPI.COMM_WORLD +rank = world.Get_rank() +size = world.Get_size() + +cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2], + periods=[True, True]) +comms = [cart_comm.Sub([True, False]), + cart_comm.Sub([False, True])] + +# Setup random keys +master_key = jax.random.PRNGKey(42) +key = jax.random.split(master_key, size)[rank] +################################################ + +# Size and parameters of the simulation volume +N = 256 +mesh_shape = [N, N, N] +box_size = [205, 205, 205] # Mpc/h +cosmo = jc.Planck15() +halo_size = 16 +a = 0.1 + + +@jax.jit +def run_sim(cosmo, key): + initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, + comms=comms) + init_field = ifft3d(initial_conditions, comms=comms).real + + # Initialize particles + pos = meshgrid3d(mesh_shape, comms=comms) + + # Initial displacement by LPT + cosmo = jc.Planck15() + dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms) + + # And now, we run an actual nbody + res = odeint(make_ode_fn(mesh_shape, halo_size, comms), + [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, + rtol=1e-5, atol=1e-5) + + # Painting on a new mesh + field = cic_paint(zeros(mesh_shape, comms=comms), + res[0][-1], halo_size, comms=comms) + + return init_field, field + + +# Recover the real space initial conditions +init_field, field = run_sim(cosmo, key) + +# Testing that the result is actually looking like what we expect +total_array, token = mpi4jax.allgather(field, comm=comms[0]) +total_array = total_array.reshape([N, N//2, N]) +total_array, token = mpi4jax.allgather( + total_array.transpose([1, 0, 2]), comm=comms[1], token=token) +total_array = total_array.reshape([N, N, N]) +total_array = total_array.transpose([1, 0, 2]) + +if rank == 0: + onp.save('simulation.npy', total_array) + +print('Done !')