fixed a whole lot of issues

This commit is contained in:
EiffL 2022-10-22 15:58:32 -04:00
parent 429813ad92
commit 72ae0fd88f
5 changed files with 251 additions and 155 deletions

View file

@ -1,84 +1,91 @@
import jax
from jax.experimental.maps import xmap
import numpy as np
import jax.numpy as jnp
from functools import partial
def fftk(shape, symmetric=True, dtype=np.float32, comms=None):
""" Return k_vector given a shape (nc, nc, nc)
"""
k = []
if comms is not None:
nx = comms[0].Get_size()
ix = comms[0].Get_rank()
ny = comms[1].Get_size()
iy = comms[1].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
""" Return k_vector given a shape (nc, nc, nc)
"""
k = []
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
if comms is not None:
nx = comms[0].Get_size()
ix = comms[0].Get_rank()
ny = comms[1].Get_size()
iy = comms[1].Get_rank()
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
if (comms is not None) and d==0:
kd = kd.reshape([nx, -1])[ix]
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
if (comms is not None) and d==1:
kd = kd.reshape([ny, -1])[iy]
if (comms is not None) and d == 0:
kd = kd.reshape([nx, -1])[ix]
k.append(kd.astype(dtype))
return k
if (comms is not None) and d == 1:
kd = kd.reshape([ny, -1])[iy]
@partial(jax.pmap,
in_axes=[['x','y','z'],
['x'],['y'],['z']],
out_axes=['x','y','z',...])
k.append(kd.astype(dtype))
return k
@partial(xmap,
in_axes=[['x', 'y', ...],
[['x'], ['y'], [...]]],
out_axes=['x', 'y', ...])
def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))],axis=-1)
kfield * kernel * 1j * 1 / 6.0 *
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
def cic_compensation(kvec):
"""
Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args:
kvec: array of k values in Fourier space
Returns:
v: array of kernel
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
"""
Computes cic compensation kernel.
Adapted from https://github.com/bccp/nbodykit/blob/a387cf429d8cb4a07bb19e3b4325ffdf279a131e/nbodykit/source/mesh/catalog.py#L499
Itself based on equation 18 (with p=2) of
`Jing et al 2005 <https://arxiv.org/abs/astro-ph/0409240>`_
Args:
kvec: array of k values in Fourier space
Returns:
v: array of kernel
"""
kwts = [np.sinc(kvec[i] / (2 * np.pi)) for i in range(3)]
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
return wts
def PGD_kernel(kvec, kl, ks):
"""
Computes the PGD kernel
Parameters:
-----------
kvec: array
Array of k values in Fourier space
kl: float
initial long range scale parameter
ks: float
initial dhort range scale parameter
Returns:
--------
v: array
kernel
"""
kk = sum(ki**2 for ki in kvec)
kl2 = kl**2
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v *= imask
return v
"""
Computes the PGD kernel
Parameters:
-----------
kvec: array
Array of k values in Fourier space
kl: float
initial long range scale parameter
ks: float
initial dhort range scale parameter
Returns:
--------
v: array
kernel
"""
kk = sum(ki**2 for ki in kvec)
kl2 = kl**2
ks4 = ks**4
mask = (kk == 0).nonzero()
kk[mask] = 1
v = jnp.exp(-kl2 / kk) * jnp.exp(-kk**2 / ks4)
imask = (~(kk == 0)).astype(int)
v *= imask
return v

View file

@ -73,34 +73,58 @@ def ifft3d(arr, comms=None):
def halo_reduce(arr, halo_size, comms=None):
if halo_size <= 0:
return arr
# Perform halo exchange along x
rank_x = comms[0].Get_rank()
size_x = comms[0].Get_size()
margin = arr[-2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
comm=comms[0])
arr = arr.at[:2*halo_size].add(margin)
left, token = mpi4jax.sendrecv(margin, margin,
(rank_x-1) % size_x,
(rank_x+1) % size_x,
comm=comms[0])
margin = arr[:2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1,
comm=comms[0], token=token)
arr = arr.at[-2*halo_size:].add(margin)
right, token = mpi4jax.sendrecv(margin, margin,
(rank_x+1) % size_x,
(rank_x-1) % size_x,
comm=comms[0], token=token)
arr = arr.at[:2*halo_size].add(left)
arr = arr.at[-2*halo_size:].add(right)
# Perform halo exchange along y
rank_y = comms[1].Get_rank()
size_y = comms[1].Get_size()
margin = arr[:, -2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
comm=comms[1], token=token)
arr = arr.at[:, :2*halo_size].add(margin)
left, token = mpi4jax.sendrecv(margin, margin,
(rank_y-1) % size_y,
(rank_y+1) % size_y,
comm=comms[1], token=token)
margin = arr[:, :2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
comm=comms[1], token=token)
arr = arr.at[:, -2*halo_size:].add(margin)
right, token = mpi4jax.sendrecv(margin, margin,
(rank_y+1) % size_y,
(rank_y-1) % size_y,
comm=comms[1], token=token)
arr = arr.at[:, :2*halo_size].add(left)
arr = arr.at[:, -2*halo_size:].add(right)
return arr
def meshgrid3d(shape, comms=None):
if comms is not None:
nx = comms[0].Get_size()
ny = comms[1].Get_size()
coords = [jnp.arange(shape[0]//nx),
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
else:
coords = [jnp.arange(s) for s in shape[2:]]
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, comms=None):
""" Initialize an array of given global shape
partitionned if need be accross dimensions.

View file

@ -6,7 +6,7 @@ from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation
def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
def cic_paint(mesh, positions, halo_size=0, comms=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
@ -43,11 +43,11 @@ def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
if comms == None:
return mesh
else:
mesh, token = halo_reduce(mesh, halo_size, token, comms)
mesh = halo_reduce(mesh, halo_size, comms)
return mesh[halo_size:-halo_size, halo_size:-halo_size]
def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
def cic_read(mesh, positions, halo_size=0, comms=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
@ -59,7 +59,7 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
mesh, token = halo_reduce(mesh, halo_size, token, comms)
mesh = halo_reduce(mesh, halo_size, comms)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
@ -75,14 +75,9 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
neighboor_coords = jnp.mod(
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
res = (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
if comms is not None:
return res
else:
return res, token
return (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
def cic_paint_2d(mesh, positions, weight):

View file

@ -1,107 +1,99 @@
import jax
from jax.experimental.maps import xmap
import jax.numpy as jnp
import jax_cosmo as jc
from jaxpm.ops import fft3d, ifft3d, zeros
from jaxpm.ops import fft3d, ifft3d, zeros, normal
from jaxpm.kernels import fftk, apply_gradient_laplace
from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None):
"""
Computes gravitational forces on particles using a PM scheme
"""
if mesh_shape is None:
mesh_shape = delta_k.shape
kvec = fftk(mesh_shape, comms=comms)
if delta_k is None:
delta, token = cic_paint(zeros(mesh_shape,comms=comms),
positions,
halo_size=halo_size, token=token, comms=comms)
delta_k, token = fft3d(delta, token=token, comms=comms)
# Computes gravitational potential
forces_k = apply_gradient_laplace(kfield, kvec)
delta = cic_paint(zeros(mesh_shape, comms=comms),
positions,
halo_size=halo_size, comms=comms)
delta_k = fft3d(delta, comms=comms)
# Computes gravitational forces
fx, token = ifft3d(forces_k[...,0], token=token, comms=comms)
fx, token = cic_read(fx, positions, halo_size=halo_size, comms=comms)
kvec = fftk(delta_k.shape, symmetric=False, comms=comms)
forces_k = apply_gradient_laplace(delta_k, kvec)
fy, token = ifft3d(forces_k[...,1], token=token, comms=comms)
fy, token = cic_read(fy, positions, halo_size=halo_size, comms=comms)
# Interpolate forces at the position of particles
return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real,
positions, halo_size=halo_size, comms=comms)
for i in range(3)], axis=-1)
fz, token = ifft3d(forces_k[...,2], token=token, comms=comms)
fz, token = cic_read(fz, positions, halo_size=halo_size, comms=comms)
return jnp.stack([fx,fy,fz],axis=-1), token
def lpt(cosmo, initial_conditions, positions, a, token=token, comms=comms):
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(positions, delta=initial_conditions, token=token, comms=comms)
initial_force = pm_forces(
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
return dx, p, f, comms
p = a**2 * growth_rate(cosmo, a) * \
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
dGfa(cosmo, a) * initial_force
return dx, p, f
def linear_field(mesh_shape, box_size, pk, seed):
def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
"""
Generate initial conditions.
Generate initial conditions in Fourier space.
"""
kvec = fftk(mesh_shape)
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
# Sample normal field
field = normal(key, mesh_shape, comms=comms)
field = jax.random.normal(seed, mesh_shape)
field = jnp.fft.rfftn(field) * pkmesh**0.5
field = jnp.fft.irfftn(field)
return field
# Transform to Fourier space
kfield = fft3d(field, comms=comms)
# Rescaling k to physical units
kvec = [k / box_size[i] * mesh_shape[i]
for i, k in enumerate(fftk(kfield.shape,
symmetric=False,
comms=comms))]
# Evaluating linear matter powerspectrum
k = jnp.logspace(-4, 2, 256)
pk = jc.power.linear_matter_power(cosmo, k)
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
) / (box_size[0] * box_size[1] * box_size[2])
# Multipliyng the field by the proper power spectrum
kfield = xmap(lambda kfield, kx, ky, kz:
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2),
k, jnp.sqrt(pk)),
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]),
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
return kfield
def make_ode_fn(mesh_shape, halo_size=0, comms=None):
def make_ode_fn(mesh_shape):
def nbody_ode(state, a, cosmo):
"""
state is a tuple (position, velocities)
"""
pos, vel = state
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
forces = pm_forces(pos, mesh_shape=mesh_shape,
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
# Computes the update of position (drift)
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
# Computes the update of velocity (kick)
dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces
return dpos, dvel
return nbody_ode
def pgd_correction(pos, params):
"""
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
args:
pos: particle positions [npart, 3]
params: [alpha, kl, ks] pgd parameters
"""
kvec = fftk(mesh_shape)
delta = cic_paint(jnp.zeros(mesh_shape), pos)
alpha, kl, ks = params
delta_k = jnp.fft.rfftn(delta)
PGD_range=PGD_kernel(kvec, kl, ks)
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
for i in range(3)],axis=-1)
dpos_pgd = forces_pgd*alpha
return dpos_pgd