mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
fixed a whole lot of issues
This commit is contained in:
parent
429813ad92
commit
72ae0fd88f
5 changed files with 251 additions and 155 deletions
|
@ -1,7 +1,11 @@
|
|||
import jax
|
||||
from jax.experimental.maps import xmap
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from functools import partial
|
||||
|
||||
def fftk(shape, symmetric=True, dtype=np.float32, comms=None):
|
||||
|
||||
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
|
||||
""" Return k_vector given a shape (nc, nc, nc)
|
||||
"""
|
||||
k = []
|
||||
|
@ -20,19 +24,20 @@ def fftk(shape, symmetric=True, dtype=np.float32, comms=None):
|
|||
if symmetric and d == len(shape) - 1:
|
||||
kd = kd[:shape[d] // 2 + 1]
|
||||
|
||||
if (comms is not None) and d==0:
|
||||
if (comms is not None) and d == 0:
|
||||
kd = kd.reshape([nx, -1])[ix]
|
||||
|
||||
if (comms is not None) and d==1:
|
||||
if (comms is not None) and d == 1:
|
||||
kd = kd.reshape([ny, -1])[iy]
|
||||
|
||||
k.append(kd.astype(dtype))
|
||||
return k
|
||||
|
||||
@partial(jax.pmap,
|
||||
in_axes=[['x','y','z'],
|
||||
['x'],['y'],['z']],
|
||||
out_axes=['x','y','z',...])
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=[['x', 'y', ...],
|
||||
[['x'], ['y'], [...]]],
|
||||
out_axes=['x', 'y', ...])
|
||||
def apply_gradient_laplace(kfield, kvec):
|
||||
kx, ky, kz = kvec
|
||||
kk = (kx**2 + ky**2 + kz**2)
|
||||
|
@ -40,7 +45,8 @@ def apply_gradient_laplace(kfield, kvec):
|
|||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
||||
kfield * kernel * 1j * 1 / 6.0 *
|
||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))],axis=-1)
|
||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
|
||||
|
||||
|
||||
def cic_compensation(kvec):
|
||||
"""
|
||||
|
@ -57,6 +63,7 @@ def cic_compensation(kvec):
|
|||
wts = (kwts[0] * kwts[1] * kwts[2])**(-2)
|
||||
return wts
|
||||
|
||||
|
||||
def PGD_kernel(kvec, kl, ks):
|
||||
"""
|
||||
Computes the PGD kernel
|
||||
|
|
44
jaxpm/ops.py
44
jaxpm/ops.py
|
@ -73,34 +73,58 @@ def ifft3d(arr, comms=None):
|
|||
|
||||
|
||||
def halo_reduce(arr, halo_size, comms=None):
|
||||
if halo_size <= 0:
|
||||
return arr
|
||||
|
||||
# Perform halo exchange along x
|
||||
rank_x = comms[0].Get_rank()
|
||||
size_x = comms[0].Get_size()
|
||||
margin = arr[-2*halo_size:]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
|
||||
left, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_x-1) % size_x,
|
||||
(rank_x+1) % size_x,
|
||||
comm=comms[0])
|
||||
arr = arr.at[:2*halo_size].add(margin)
|
||||
|
||||
margin = arr[:2*halo_size]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1,
|
||||
right, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_x+1) % size_x,
|
||||
(rank_x-1) % size_x,
|
||||
comm=comms[0], token=token)
|
||||
arr = arr.at[-2*halo_size:].add(margin)
|
||||
|
||||
arr = arr.at[:2*halo_size].add(left)
|
||||
arr = arr.at[-2*halo_size:].add(right)
|
||||
|
||||
# Perform halo exchange along y
|
||||
rank_y = comms[1].Get_rank()
|
||||
size_y = comms[1].Get_size()
|
||||
margin = arr[:, -2*halo_size:]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
|
||||
left, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_y-1) % size_y,
|
||||
(rank_y+1) % size_y,
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, :2*halo_size].add(margin)
|
||||
|
||||
margin = arr[:, :2*halo_size]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
|
||||
right, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_y+1) % size_y,
|
||||
(rank_y-1) % size_y,
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, -2*halo_size:].add(margin)
|
||||
arr = arr.at[:, :2*halo_size].add(left)
|
||||
arr = arr.at[:, -2*halo_size:].add(right)
|
||||
|
||||
return arr
|
||||
|
||||
|
||||
def meshgrid3d(shape, comms=None):
|
||||
if comms is not None:
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
coords = [jnp.arange(shape[0]//nx),
|
||||
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
|
||||
else:
|
||||
coords = [jnp.arange(s) for s in shape[2:]]
|
||||
|
||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||
|
||||
|
||||
def zeros(shape, comms=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
|
|
|
@ -6,7 +6,7 @@ from jaxpm.ops import halo_reduce
|
|||
from jaxpm.kernels import fftk, cic_compensation
|
||||
|
||||
|
||||
def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
|
||||
def cic_paint(mesh, positions, halo_size=0, comms=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
|
@ -43,11 +43,11 @@ def cic_paint(mesh, positions, halo_size=0, token=None, comms=None):
|
|||
if comms == None:
|
||||
return mesh
|
||||
else:
|
||||
mesh, token = halo_reduce(mesh, halo_size, token, comms)
|
||||
mesh = halo_reduce(mesh, halo_size, comms)
|
||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
|
||||
def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
||||
def cic_read(mesh, positions, halo_size=0, comms=None):
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
|
@ -59,7 +59,7 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
|||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||
[halo_size, halo_size],
|
||||
[0, 0]])
|
||||
mesh, token = halo_reduce(mesh, halo_size, token, comms)
|
||||
mesh = halo_reduce(mesh, halo_size, comms)
|
||||
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
||||
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
|
@ -75,15 +75,10 @@ def cic_read(mesh, positions, halo_size=0, token=None, comms=None):
|
|||
neighboor_coords = jnp.mod(
|
||||
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
||||
|
||||
res = (mesh[neighboor_coords[..., 0],
|
||||
return (mesh[neighboor_coords[..., 0],
|
||||
neighboor_coords[..., 1],
|
||||
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
|
||||
|
||||
if comms is not None:
|
||||
return res
|
||||
else:
|
||||
return res, token
|
||||
|
||||
|
||||
def cic_paint_2d(mesh, positions, weight):
|
||||
""" Paints positions onto a 2d mesh
|
||||
|
|
112
jaxpm/pm.py
112
jaxpm/pm.py
|
@ -1,68 +1,83 @@
|
|||
import jax
|
||||
from jax.experimental.maps import xmap
|
||||
import jax.numpy as jnp
|
||||
|
||||
import jax_cosmo as jc
|
||||
|
||||
from jaxpm.ops import fft3d, ifft3d, zeros
|
||||
from jaxpm.ops import fft3d, ifft3d, zeros, normal
|
||||
from jaxpm.kernels import fftk, apply_gradient_laplace
|
||||
from jaxpm.painting import cic_paint, cic_read
|
||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||
|
||||
|
||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None):
|
||||
"""
|
||||
Computes gravitational forces on particles using a PM scheme
|
||||
"""
|
||||
if mesh_shape is None:
|
||||
mesh_shape = delta_k.shape
|
||||
|
||||
kvec = fftk(mesh_shape, comms=comms)
|
||||
|
||||
if delta_k is None:
|
||||
delta, token = cic_paint(zeros(mesh_shape,comms=comms),
|
||||
delta = cic_paint(zeros(mesh_shape, comms=comms),
|
||||
positions,
|
||||
halo_size=halo_size, token=token, comms=comms)
|
||||
delta_k, token = fft3d(delta, token=token, comms=comms)
|
||||
|
||||
# Computes gravitational potential
|
||||
forces_k = apply_gradient_laplace(kfield, kvec)
|
||||
halo_size=halo_size, comms=comms)
|
||||
delta_k = fft3d(delta, comms=comms)
|
||||
|
||||
# Computes gravitational forces
|
||||
fx, token = ifft3d(forces_k[...,0], token=token, comms=comms)
|
||||
fx, token = cic_read(fx, positions, halo_size=halo_size, comms=comms)
|
||||
kvec = fftk(delta_k.shape, symmetric=False, comms=comms)
|
||||
forces_k = apply_gradient_laplace(delta_k, kvec)
|
||||
|
||||
fy, token = ifft3d(forces_k[...,1], token=token, comms=comms)
|
||||
fy, token = cic_read(fy, positions, halo_size=halo_size, comms=comms)
|
||||
# Interpolate forces at the position of particles
|
||||
return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real,
|
||||
positions, halo_size=halo_size, comms=comms)
|
||||
for i in range(3)], axis=-1)
|
||||
|
||||
fz, token = ifft3d(forces_k[...,2], token=token, comms=comms)
|
||||
fz, token = cic_read(fz, positions, halo_size=halo_size, comms=comms)
|
||||
|
||||
return jnp.stack([fx,fy,fz],axis=-1), token
|
||||
|
||||
def lpt(cosmo, initial_conditions, positions, a, token=token, comms=comms):
|
||||
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
|
||||
"""
|
||||
Computes first order LPT displacement
|
||||
"""
|
||||
initial_force = pm_forces(positions, delta=initial_conditions, token=token, comms=comms)
|
||||
initial_force = pm_forces(
|
||||
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
|
||||
a = jnp.atleast_1d(a)
|
||||
dx = growth_factor(cosmo, a) * initial_force
|
||||
p = a**2 * growth_rate(cosmo, a) * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * dGfa(cosmo, a) * initial_force
|
||||
return dx, p, f, comms
|
||||
p = a**2 * growth_rate(cosmo, a) * \
|
||||
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
|
||||
dGfa(cosmo, a) * initial_force
|
||||
return dx, p, f
|
||||
|
||||
def linear_field(mesh_shape, box_size, pk, seed):
|
||||
|
||||
def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
|
||||
"""
|
||||
Generate initial conditions.
|
||||
Generate initial conditions in Fourier space.
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
kmesh = sum((kk / box_size[i] * mesh_shape[i])**2 for i, kk in enumerate(kvec))**0.5
|
||||
pkmesh = pk(kmesh) * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]) / (box_size[0] * box_size[1] * box_size[2])
|
||||
# Sample normal field
|
||||
field = normal(key, mesh_shape, comms=comms)
|
||||
|
||||
field = jax.random.normal(seed, mesh_shape)
|
||||
field = jnp.fft.rfftn(field) * pkmesh**0.5
|
||||
field = jnp.fft.irfftn(field)
|
||||
return field
|
||||
# Transform to Fourier space
|
||||
kfield = fft3d(field, comms=comms)
|
||||
|
||||
def make_ode_fn(mesh_shape):
|
||||
# Rescaling k to physical units
|
||||
kvec = [k / box_size[i] * mesh_shape[i]
|
||||
for i, k in enumerate(fftk(kfield.shape,
|
||||
symmetric=False,
|
||||
comms=comms))]
|
||||
|
||||
# Evaluating linear matter powerspectrum
|
||||
k = jnp.logspace(-4, 2, 256)
|
||||
pk = jc.power.linear_matter_power(cosmo, k)
|
||||
pk = pk * (mesh_shape[0] * mesh_shape[1] * mesh_shape[2]
|
||||
) / (box_size[0] * box_size[1] * box_size[2])
|
||||
|
||||
# Multipliyng the field by the proper power spectrum
|
||||
kfield = xmap(lambda kfield, kx, ky, kz:
|
||||
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2),
|
||||
k, jnp.sqrt(pk)),
|
||||
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]),
|
||||
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
|
||||
|
||||
return kfield
|
||||
|
||||
|
||||
def make_ode_fn(mesh_shape, halo_size=0, comms=None):
|
||||
|
||||
def nbody_ode(state, a, cosmo):
|
||||
"""
|
||||
|
@ -70,7 +85,8 @@ def make_ode_fn(mesh_shape):
|
|||
"""
|
||||
pos, vel = state
|
||||
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape) * 1.5 * cosmo.Omega_m
|
||||
forces = pm_forces(pos, mesh_shape=mesh_shape,
|
||||
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
|
||||
|
||||
# Computes the update of position (drift)
|
||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||
|
@ -81,27 +97,3 @@ def make_ode_fn(mesh_shape):
|
|||
return dpos, dvel
|
||||
|
||||
return nbody_ode
|
||||
|
||||
|
||||
def pgd_correction(pos, params):
|
||||
"""
|
||||
improve the short-range interactions of PM-Nbody simulations with potential gradient descent method, based on https://arxiv.org/abs/1804.00671
|
||||
args:
|
||||
pos: particle positions [npart, 3]
|
||||
params: [alpha, kl, ks] pgd parameters
|
||||
"""
|
||||
kvec = fftk(mesh_shape)
|
||||
|
||||
delta = cic_paint(jnp.zeros(mesh_shape), pos)
|
||||
alpha, kl, ks = params
|
||||
delta_k = jnp.fft.rfftn(delta)
|
||||
PGD_range=PGD_kernel(kvec, kl, ks)
|
||||
|
||||
pot_k_pgd=(delta_k * laplace_kernel(kvec))*PGD_range
|
||||
|
||||
forces_pgd= jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i)*pot_k_pgd), pos)
|
||||
for i in range(3)],axis=-1)
|
||||
|
||||
dpos_pgd = forces_pgd*alpha
|
||||
|
||||
return dpos_pgd
|
78
scripts/test_nbody.py
Normal file
78
scripts/test_nbody.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
from dataclasses import fields
|
||||
from mpi4py import MPI
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
import mpi4jax
|
||||
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
|
||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||
from jaxpm.painting import cic_paint
|
||||
from jax.experimental.ode import odeint
|
||||
import jax_cosmo as jc
|
||||
|
||||
|
||||
### Setting up a whole bunch of things #######
|
||||
# Create communicators
|
||||
world = MPI.COMM_WORLD
|
||||
rank = world.Get_rank()
|
||||
size = world.Get_size()
|
||||
|
||||
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
|
||||
periods=[True, True])
|
||||
comms = [cart_comm.Sub([True, False]),
|
||||
cart_comm.Sub([False, True])]
|
||||
|
||||
# Setup random keys
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
key = jax.random.split(master_key, size)[rank]
|
||||
################################################
|
||||
|
||||
# Size and parameters of the simulation volume
|
||||
N = 256
|
||||
mesh_shape = [N, N, N]
|
||||
box_size = [205, 205, 205] # Mpc/h
|
||||
cosmo = jc.Planck15()
|
||||
halo_size = 16
|
||||
a = 0.1
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_sim(cosmo, key):
|
||||
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||
comms=comms)
|
||||
init_field = ifft3d(initial_conditions, comms=comms).real
|
||||
|
||||
# Initialize particles
|
||||
pos = meshgrid3d(mesh_shape, comms=comms)
|
||||
|
||||
# Initial displacement by LPT
|
||||
cosmo = jc.Planck15()
|
||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
|
||||
|
||||
# And now, we run an actual nbody
|
||||
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
|
||||
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
rtol=1e-5, atol=1e-5)
|
||||
|
||||
# Painting on a new mesh
|
||||
field = cic_paint(zeros(mesh_shape, comms=comms),
|
||||
res[0][-1], halo_size, comms=comms)
|
||||
|
||||
return init_field, field
|
||||
|
||||
|
||||
# Recover the real space initial conditions
|
||||
init_field, field = run_sim(cosmo, key)
|
||||
|
||||
# Testing that the result is actually looking like what we expect
|
||||
total_array, token = mpi4jax.allgather(field, comm=comms[0])
|
||||
total_array = total_array.reshape([N, N//2, N])
|
||||
total_array, token = mpi4jax.allgather(
|
||||
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||
total_array = total_array.reshape([N, N, N])
|
||||
total_array = total_array.transpose([1, 0, 2])
|
||||
|
||||
if rank == 0:
|
||||
onp.save('simulation.npy', total_array)
|
||||
|
||||
print('Done !')
|
Loading…
Add table
Reference in a new issue