fixed a whole lot of issues

This commit is contained in:
EiffL 2022-10-22 15:58:32 -04:00
parent 429813ad92
commit 72ae0fd88f
5 changed files with 251 additions and 155 deletions

View file

@ -1,84 +1,91 @@
import jax
from jax.experimental.maps import xmap
import numpy as np import numpy as np
import jax.numpy as jnp import jax.numpy as jnp
from functools import partial
def fftk(shape, symmetric=True, dtype=np.float32, comms=None):
""" Return k_vector given a shape (nc, nc, nc)
"""
k = []
if comms is not None: def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
nx = comms[0].Get_size() """ Return k_vector given a shape (nc, nc, nc)
ix = comms[0].Get_rank() """
ny = comms[1].Get_size() k = []
iy = comms[1].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
for d in range(len(shape)): if comms is not None:
kd = np.fft.fftfreq(shape[d]) nx = comms[0].Get_size()
kd *= 2 * np.pi ix = comms[0].Get_rank()
ny = comms[1].Get_size()
iy = comms[1].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
if symmetric and d == len(shape) - 1: for d in range(len(shape)):
kd = kd[:shape[d] // 2 + 1] kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
if (comms is not None) and d==0: if symmetric and d == len(shape) - 1:
kd = kd.reshape([nx, -1])[ix] kd = kd[:shape[d] // 2 + 1]
if (comms is not None) and d==1: if (comms is not None) and d == 0:
kd = kd.reshape([ny, -1])[iy] kd = kd.reshape([nx, -1])[ix]
k.append(kd.astype(dtype)) if (comms is not None) and d == 1:
return k kd = kd.reshape([ny, -1])[iy]
@partial(jax.pmap, k.append(kd.astype(dtype))
in_axes=[['x','y','z'], return k
['x'],['y'],['z']],
out_axes=['x','y','z',...])
@partial(xmap,
in_axes=[['x', 'y', ...],
[['x'], ['y'], [...]]],
out_axes=['x', 'y', ...])
def apply_gradient_laplace(kfield, kvec): def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2) kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk) kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)), return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 * kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)), (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))],axis=-1) kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
def cic_compensation(kvec): def cic_compensation(kvec):
""" """
Computes cic compensation kernel. Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499 Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_ `Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args: Args:
kvec: array of k values in Fourier space kvec: array of k values in Fourier space
Returns: Returns:
v: array of kernel v: array of kernel
""" """
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)] kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2) wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts return wts
def PGD_kernel(kvec, kl, ks): def PGD_kernel(kvec, kl, ks):
""" """
Computes the PGD kernel Computes the PGD kernel
Parameters: Parameters:
----------- -----------
kvec: array kvec: array
Array of k values in Fourier space Array of k values in Fourier space
kl: float kl: float
initial long range scale parameter initial long range scale parameter
ks: float ks: float
initial dhort range scale parameter initial dhort range scale parameter
Returns: Returns:
-------- --------
v: array v: array
kernel kernel
""" """
kk = sum(ki**2 for ki in kvec) kk = sum(ki**2 for ki in kvec)
kl2 = kl**2 kl2 = kl**2
ks4 = ks**4 ks4 = ks**4
mask = (kk == 0).nonzero() mask = (kk == 0).nonzero()
kk[mask] = 1 kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4) v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int) imask = (~(kk == 0)).astype(int)
v *= imask v *= imask
return v return v

View file

@ -73,34 +73,58 @@ def ifft3d(arr, comms=None):
def halo_reduce(arr, halo_size, comms=None): def halo_reduce(arr, halo_size, comms=None):
if halo_size <= 0:
return arr
# Perform halo exchange along x # Perform halo exchange along x
rank_x = comms[0].Get_rank() rank_x = comms[0].Get_rank()
size_x = comms[0].Get_size()
margin = arr[-2*halo_size:] margin = arr[-2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1, left, token = mpi4jax.sendrecv(margin, margin,
comm=comms[0]) (rank_x-1) % size_x,
arr = arr.at[:2*halo_size].add(margin) (rank_x+1) % size_x,
comm=comms[0])
margin = arr[:2*halo_size] margin = arr[:2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1, right, token = mpi4jax.sendrecv(margin, margin,
comm=comms[0], token=token) (rank_x+1) % size_x,
arr = arr.at[-2*halo_size:].add(margin) (rank_x-1) % size_x,
comm=comms[0], token=token)
arr = arr.at[:2*halo_size].add(left)
arr = arr.at[-2*halo_size:].add(right)
# Perform halo exchange along y # Perform halo exchange along y
rank_y = comms[1].Get_rank() rank_y = comms[1].Get_rank()
size_y = comms[1].Get_size()
margin = arr[:, -2*halo_size:] margin = arr[:, -2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1, left, token = mpi4jax.sendrecv(margin, margin,
comm=comms[1], token=token) (rank_y-1) % size_y,
arr = arr.at[:, :2*halo_size].add(margin) (rank_y+1) % size_y,
comm=comms[1], token=token)
margin = arr[:, :2*halo_size] margin = arr[:, :2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1, right, token = mpi4jax.sendrecv(margin, margin,
comm=comms[1], token=token) (rank_y+1) % size_y,
arr = arr.at[:, -2*halo_size:].add(margin) (rank_y-1) % size_y,
comm=comms[1], token=token)
arr = arr.at[:, :2*halo_size].add(left)
arr = arr.at[:, -2*halo_size:].add(right)
return arr return arr
def meshgrid3d(shape, comms=None):
if comms is not None:
nx = comms[0].Get_size()
ny = comms[1].Get_size()
coords = [jnp.arange(shape[0]//nx),
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
else:
coords = [jnp.arange(s) for s in shape[2:]]
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, comms=None): def zeros(shape, comms=None):
""" Initialize an array of given global shape """ Initialize an array of given global shape
partitionned if need be accross dimensions. partitionned if need be accross dimensions.

View file

@ -6,7 +6,7 @@ from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, halo_size=0, token=None, comms=None): def cic_paint(mesh, positions, halo_size=0, comms=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
@ -43,11 +43,11 @@ def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
if comms == None: if comms == None:
return mesh return mesh
else: else:
mesh, token = halo_reduce(mesh, halo_size, token, comms) mesh = halo_reduce(mesh, halo_size, comms)
return mesh[halo_size:-halo_size, halo_size:-halo_size] return mesh[halo_size:-halo_size, halo_size:-halo_size]
def cic_read(mesh, positions, halo_size=0, token=None, comms=None): def cic_read(mesh, positions, halo_size=0, comms=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
@ -59,7 +59,7 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
mesh = jnp.pad(mesh, [[halo_size, halo_size], mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size], [halo_size, halo_size],
[0, 0]]) [0, 0]])
mesh, token = halo_reduce(mesh, halo_size, token, comms) mesh = halo_reduce(mesh, halo_size, comms)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3]) positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
@ -75,14 +75,9 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
neighboor_coords = jnp.mod( neighboor_coords = jnp.mod(
neighboor_coords.astype('int32'), jnp.array(mesh.shape)) neighboor_coords.astype('int32'), jnp.array(mesh.shape))
res = (mesh[neighboor_coords[..., 0], return (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1], neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1) neighboor_coords[..., 3]]*kernel).sum(axis=-1)
if comms is not None:
return res
else:
return res, token
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):

View file

@ -1,107 +1,99 @@
import jax import jax
from jax.experimental.maps import xmap
import jax.numpy as jnp import jax.numpy as jnp
import jax_cosmo as jc import jax_cosmo as jc
from jaxpm.ops import fft3d, ifft3d, zeros from jaxpm.ops import fft3d, ifft3d, zeros, normal
from jaxpm.kernels import fftk, apply_gradient_laplace from jaxpm.kernels import fftk, apply_gradient_laplace
from jaxpm.painting import cic_paint, cic_read from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None): def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None):
""" """
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
""" """
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = fftk(mesh_shape, comms=comms)
if delta_k is None: if delta_k is None:
delta, token = cic_paint(zeros(mesh_shape,comms=comms), delta = cic_paint(zeros(mesh_shape, comms=comms),
positions, positions,
halo_size=halo_size, token=token, comms=comms) halo_size=halo_size, comms=comms)
delta_k, token = fft3d(delta, token=token, comms=comms) delta_k = fft3d(delta, comms=comms)
# Computes gravitational potential
forces_k = apply_gradient_laplace(kfield, kvec)
# Computes gravitational forces # Computes gravitational forces
fx, token = ifft3d(forces_k[...,0], token=token, comms=comms) kvec = fftk(delta_k.shape, symmetric=False, comms=comms)
fx, token = cic_read(fx, positions, halo_size=halo_size, comms=comms) forces_k = apply_gradient_laplace(delta_k, kvec)
fy, token = ifft3d(forces_k[...,1], token=token, comms=comms) # Interpolate forces at the position of particles
fy, token = cic_read(fy, positions, halo_size=halo_size, comms=comms) return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real,
positions, halo_size=halo_size, comms=comms)
for i in range(3)], axis=-1)
fz, token = ifft3d(forces_k[...,2], token=token, comms=comms)
fz, token = cic_read(fz, positions, halo_size=halo_size, comms=comms)
return jnp.stack([fx,fy,fz],axis=-1), token def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
def lpt(cosmo, initial_conditions, positions, a, token=token, comms=comms):
""" """
Computes first order LPT displacement Computes first order LPT displacement
""" """
initial_force = pm_forces(positions, delta=initial_conditions, token=token, comms=comms) initial_force = pm_forces(
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx p = a**2 * growth_rate(cosmo, a) * \
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
return dx, p, f, comms f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
dGfa(cosmo, a) * initial_force
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
""" """
Generate initial conditions. Generate initial conditions in Fourier space.
""" """
kvec = fftk(mesh_shape) # Sample normal field
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5 field = normal(key, mesh_shape, comms=comms)
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
field = jax.random.normal(seed, mesh_shape) # Transform to Fourier space
field = jnp.fft.rfftn(field) * pkmesh**0.5 kfield = fft3d(field, comms=comms)
field = jnp.fft.irfftn(field)
return field # Rescaling k to physical units
kvec = [k / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(kfield.shape,
symmetric=False,
comms=comms))]
# Evaluating linear matter powerspectrum
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
) / (box_size[0] * box_size[1] * box_size[2])
# Multipliyng the field by the proper power spectrum
kfield = xmap(lambda kfield, kx, ky, kz:
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2),
k, jnp.sqrt(pk)),
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]),
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
return kfield
def make_ode_fn(mesh_shape, halo_size=0, comms=None):
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo): def nbody_ode(state, a, cosmo):
""" """
state is a tuple (position, velocities) state is a tuple (position, velocities)
""" """
pos, vel = state pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m forces = pm_forces(pos, mesh_shape=mesh_shape,
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift) # Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel return dpos, dvel
return nbody_ode return nbody_ode
def pgd_correction(pos, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
args:
pos: particle positions [npart, 3]
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range=PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
dpos_pgd = forces_pgd*alpha
return dpos_pgd

78
scripts/test_nbody.py Normal file
View file

@ -0,0 +1,78 @@
from dataclasses import fields
from mpi4py import MPI
import jax
import jax.numpy as jnp
import numpy as onp
import mpi4jax
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
from jaxpm.pm import linear_field, lpt, make_ode_fn
from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint
import jax_cosmo as jc
### Setting up a whole bunch of things #######
# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
periods=[True, True])
comms = [cart_comm.Sub([True, False]),
cart_comm.Sub([False, True])]
# Setup random keys
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]
################################################
# Size and parameters of the simulation volume
N = 256
mesh_shape = [N, N, N]
box_size = [205, 205, 205] # Mpc/h
cosmo = jc.Planck15()
halo_size = 16
a = 0.1
@jax.jit
def run_sim(cosmo, key):
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
comms=comms)
init_field = ifft3d(initial_conditions, comms=comms).real
# Initialize particles
pos = meshgrid3d(mesh_shape, comms=comms)
# Initial displacement by LPT
cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
# And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-5, atol=1e-5)
# Painting on a new mesh
field = cic_paint(zeros(mesh_shape, comms=comms),
res[0][-1], halo_size, comms=comms)
return init_field, field
# Recover the real space initial conditions
init_field, field = run_sim(cosmo, key)
# Testing that the result is actually looking like what we expect
total_array, token = mpi4jax.allgather(field, comm=comms[0])
total_array = total_array.reshape([N, N//2, N])
total_array, token = mpi4jax.allgather(
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
total_array = total_array.reshape([N, N, N])
total_array = total_array.transpose([1, 0, 2])
if rank == 0:
onp.save('simulation.npy', total_array)
print('Done !')