mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Adding an example of jaxdecomp implementation
This commit is contained in:
parent
6644b35d71
commit
6ca4c9191e
5 changed files with 166 additions and 192 deletions
|
@ -3,19 +3,23 @@ from jax.experimental.maps import xmap
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import jaxdecomp
|
||||||
|
|
||||||
|
def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
|
||||||
def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
|
|
||||||
""" Return k_vector given a shape (nc, nc, nc)
|
""" Return k_vector given a shape (nc, nc, nc)
|
||||||
"""
|
"""
|
||||||
k = []
|
k = []
|
||||||
|
|
||||||
if comms is not None:
|
if sharding_info is not None:
|
||||||
nx = comms[0].Get_size()
|
nx = sharding_info.pdims[1]
|
||||||
ix = comms[0].Get_rank()
|
ny = sharding_info.pdims[0]
|
||||||
ny = comms[1].Get_size()
|
# nx = sharding_info[0].Get_size()
|
||||||
iy = comms[1].Get_rank()
|
# ix = sharding_info[0].Get_rank()
|
||||||
shape = [shape[0]*nx, shape[1]*ny] + list(shape[2:])
|
# ny = sharding_info[1].Get_size()
|
||||||
|
# iy = sharding_info[1].Get_rank()
|
||||||
|
ix = sharding_info.rank
|
||||||
|
iy = 0
|
||||||
|
shape = sharding_info.global_shape
|
||||||
|
|
||||||
for d in range(len(shape)):
|
for d in range(len(shape)):
|
||||||
kd = np.fft.fftfreq(shape[d])
|
kd = np.fft.fftfreq(shape[d])
|
||||||
|
@ -24,10 +28,10 @@ def fftk(shape, symmetric=False, dtype=np.float32, comms=None):
|
||||||
if symmetric and d == len(shape) - 1:
|
if symmetric and d == len(shape) - 1:
|
||||||
kd = kd[:shape[d] // 2 + 1]
|
kd = kd[:shape[d] // 2 + 1]
|
||||||
|
|
||||||
if (comms is not None) and d == 0:
|
if (sharding_info is not None) and d == 0:
|
||||||
kd = kd.reshape([nx, -1])[ix]
|
kd = kd.reshape([nx, -1])[ix]
|
||||||
|
|
||||||
if (comms is not None) and d == 1:
|
if (sharding_info is not None) and d == 1:
|
||||||
kd = kd.reshape([ny, -1])[iy]
|
kd = kd.reshape([ny, -1])[iy]
|
||||||
|
|
||||||
k.append(kd.astype(dtype))
|
k.append(kd.astype(dtype))
|
||||||
|
@ -42,10 +46,9 @@ def apply_gradient_laplace(kfield, kvec):
|
||||||
kx, ky, kz = kvec
|
kx, ky, kz = kvec
|
||||||
kk = (kx**2 + ky**2 + kz**2)
|
kk = (kx**2 + ky**2 + kz**2)
|
||||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky)),
|
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||||
kfield * kernel * 1j * 1 / 6.0 *
|
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx)),
|
||||||
(8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(ky) - jnp.sin(2 * ky))], axis=-1)
|
||||||
kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kx) - jnp.sin(2 * kx))], axis=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def cic_compensation(kvec):
|
def cic_compensation(kvec):
|
||||||
|
|
176
jaxpm/ops.py
176
jaxpm/ops.py
|
@ -2,155 +2,91 @@
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import mpi4jax
|
import mpi4jax
|
||||||
|
import jaxdecomp
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ShardingInfo:
|
||||||
|
"""Class for keeping track of the distribution strategy"""
|
||||||
|
global_shape: Tuple[int, int, int]
|
||||||
|
pdims: Tuple[int, int]
|
||||||
|
halo_extents: Tuple[int, int, int]
|
||||||
|
rank: int = 0
|
||||||
|
|
||||||
|
|
||||||
def fft3d(arr, comms=None):
|
def fft3d(arr, sharding_info=None):
|
||||||
""" Computes forward FFT, note that the output is transposed
|
""" Computes forward FFT, note that the output is transposed
|
||||||
"""
|
"""
|
||||||
if comms is not None:
|
if sharding_info is None:
|
||||||
shape = list(arr.shape)
|
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
|
||||||
nx = comms[0].Get_size()
|
|
||||||
ny = comms[1].Get_size()
|
|
||||||
|
|
||||||
# First FFT along z
|
|
||||||
arr = jnp.fft.fft(arr) # [x, y, z]
|
|
||||||
# Perform single gpu or distributed transpose
|
|
||||||
if comms == None:
|
|
||||||
arr = arr.transpose([1, 2, 0])
|
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
arr = jaxdecomp.pfft3d(arr,
|
||||||
#arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
pdims=sharding_info.pdims,
|
||||||
arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
global_shape=sharding_info.global_shape)
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
|
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
|
||||||
|
|
||||||
# Second FFT along x
|
|
||||||
arr = jnp.fft.fft(arr)
|
|
||||||
# Perform single gpu or distributed transpose
|
|
||||||
if comms == None:
|
|
||||||
arr = arr.transpose([1, 2, 0])
|
|
||||||
else:
|
|
||||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
|
||||||
#arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
|
|
||||||
arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
|
||||||
|
|
||||||
# Third FFT along y
|
|
||||||
return jnp.fft.fft(arr)
|
|
||||||
|
|
||||||
|
|
||||||
def ifft3d(arr, comms=None):
|
|
||||||
""" Let's assume that the data is distributed accross x
|
|
||||||
"""
|
|
||||||
if comms is not None:
|
|
||||||
shape = list(arr.shape)
|
|
||||||
nx = comms[0].Get_size()
|
|
||||||
ny = comms[1].Get_size()
|
|
||||||
|
|
||||||
# First FFT along y
|
|
||||||
arr = jnp.fft.ifft(arr) # Now [z, x, y]
|
|
||||||
# Perform single gpu or distributed transpose
|
|
||||||
if comms == None:
|
|
||||||
arr = arr.transpose([0, 2, 1])
|
|
||||||
else:
|
|
||||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
|
||||||
# arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
|
||||||
arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
|
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
|
||||||
|
|
||||||
# Second FFT along x
|
|
||||||
arr = jnp.fft.ifft(arr)
|
|
||||||
# Perform single gpu or distributed transpose
|
|
||||||
if comms == None:
|
|
||||||
arr = arr.transpose([2, 1, 0])
|
|
||||||
else:
|
|
||||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
|
||||||
# arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
|
|
||||||
arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
|
||||||
|
|
||||||
# Third FFT along z
|
|
||||||
return jnp.fft.ifft(arr)
|
|
||||||
|
|
||||||
|
|
||||||
def halo_reduce(arr, halo_size, comms=None):
|
|
||||||
if halo_size <= 0:
|
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
# Perform halo exchange along x
|
|
||||||
rank_x = comms[0].Get_rank()
|
|
||||||
size_x = comms[0].Get_size()
|
|
||||||
margin = arr[-2*halo_size:]
|
|
||||||
left, token = mpi4jax.sendrecv(margin, margin,
|
|
||||||
(rank_x-1) % size_x,
|
|
||||||
(rank_x+1) % size_x,
|
|
||||||
comm=comms[0])
|
|
||||||
margin = arr[:2*halo_size]
|
|
||||||
right, token = mpi4jax.sendrecv(margin, margin,
|
|
||||||
(rank_x+1) % size_x,
|
|
||||||
(rank_x-1) % size_x,
|
|
||||||
comm=comms[0], token=token)
|
|
||||||
|
|
||||||
arr = arr.at[:2*halo_size].add(left)
|
def ifft3d(arr, sharding_info=None):
|
||||||
arr = arr.at[-2*halo_size:].add(right)
|
if sharding_info is None:
|
||||||
|
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
|
||||||
|
else:
|
||||||
|
arr = jaxdecomp.pifft3d(arr,
|
||||||
|
pdims=sharding_info.pdims,
|
||||||
|
global_shape=sharding_info.global_shape)
|
||||||
|
return arr
|
||||||
|
|
||||||
# Perform halo exchange along y
|
|
||||||
rank_y = comms[1].Get_rank()
|
def halo_reduce(arr, sharding_info=None):
|
||||||
size_y = comms[1].Get_size()
|
if sharding_info is None:
|
||||||
margin = arr[:, -2*halo_size:]
|
return arr
|
||||||
left, token = mpi4jax.sendrecv(margin, margin,
|
halo_size = sharding_info.halo_extents[0]
|
||||||
(rank_y-1) % size_y,
|
global_shape = sharding_info.global_shape
|
||||||
(rank_y+1) % size_y,
|
arr = jaxdecomp.halo_exchange(arr,
|
||||||
comm=comms[1], token=token)
|
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||||
margin = arr[:, :2*halo_size]
|
halo_periods=(True,True,True),
|
||||||
right, token = mpi4jax.sendrecv(margin, margin,
|
pdims=sharding_info.pdims,
|
||||||
(rank_y+1) % size_y,
|
global_shape=(global_shape[0]+2*halo_size,
|
||||||
(rank_y-1) % size_y,
|
global_shape[1]+halo_size,
|
||||||
comm=comms[1], token=token)
|
global_shape[2]))
|
||||||
arr = arr.at[:, :2*halo_size].add(left)
|
|
||||||
arr = arr.at[:, -2*halo_size:].add(right)
|
# Apply correction along x
|
||||||
|
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||||
|
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||||
|
|
||||||
|
# Apply correction along y
|
||||||
|
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||||
|
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||||
|
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def meshgrid3d(shape, comms=None):
|
def meshgrid3d(shape, sharding_info=None):
|
||||||
if comms is not None:
|
if sharding_info is not None:
|
||||||
nx = comms[0].Get_size()
|
coords = [jnp.arange(sharding_info.global_shape[0]//sharding_info.pdims[1]),
|
||||||
ny = comms[1].Get_size()
|
jnp.arange(sharding_info.global_shape[1]//sharding_info.pdims[0]), jnp.arange(sharding_info.global_shape[2])]
|
||||||
|
|
||||||
coords = [jnp.arange(shape[0]//nx),
|
|
||||||
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
|
|
||||||
else:
|
else:
|
||||||
coords = [jnp.arange(s) for s in shape[2:]]
|
coords = [jnp.arange(s) for s in shape[2:]]
|
||||||
|
|
||||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||||
|
|
||||||
|
|
||||||
def zeros(shape, comms=None):
|
def zeros(shape, sharding_info=None):
|
||||||
""" Initialize an array of given global shape
|
""" Initialize an array of given global shape
|
||||||
partitionned if need be accross dimensions.
|
partitionned if need be accross dimensions.
|
||||||
"""
|
"""
|
||||||
if comms is None:
|
if sharding_info is None:
|
||||||
return jnp.zeros(shape)
|
return jnp.zeros(shape)
|
||||||
|
|
||||||
nx = comms[0].Get_size()
|
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||||
ny = comms[1].Get_size()
|
|
||||||
|
|
||||||
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
|
||||||
|
|
||||||
|
|
||||||
def normal(key, shape, comms=None):
|
def normal(key, shape, sharding_info=None):
|
||||||
""" Generates a normal variable for the given
|
""" Generates a normal variable for the given
|
||||||
global shape.
|
global shape.
|
||||||
"""
|
"""
|
||||||
if comms is None:
|
if sharding_info is None:
|
||||||
return jax.random.normal(key, shape)
|
return jax.random.normal(key, shape)
|
||||||
|
|
||||||
nx = comms[0].Get_size()
|
|
||||||
ny = comms[1].Get_size()
|
|
||||||
|
|
||||||
return jax.random.normal(key,
|
return jax.random.normal(key,
|
||||||
[shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
[sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0], sharding_info.global_shape[2]])
|
||||||
|
|
|
@ -6,12 +6,12 @@ from jaxpm.ops import halo_reduce
|
||||||
from jaxpm.kernels import fftk, cic_compensation
|
from jaxpm.kernels import fftk, cic_compensation
|
||||||
|
|
||||||
|
|
||||||
def cic_paint(mesh, positions, halo_size=0, comms=None):
|
def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
if comms is not None:
|
if sharding_info is not None:
|
||||||
# Add some padding for the halo exchange
|
# Add some padding for the halo exchange
|
||||||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||||
[halo_size, halo_size],
|
[halo_size, halo_size],
|
||||||
|
@ -40,26 +40,32 @@ def cic_paint(mesh, positions, halo_size=0, comms=None):
|
||||||
kernel.reshape([-1, 8]),
|
kernel.reshape([-1, 8]),
|
||||||
dnums)
|
dnums)
|
||||||
|
|
||||||
if comms == None:
|
if sharding_info == None:
|
||||||
return mesh
|
return mesh
|
||||||
else:
|
else:
|
||||||
mesh = halo_reduce(mesh, halo_size, comms)
|
mesh = halo_reduce(mesh, sharding_info)
|
||||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
|
||||||
def cic_read(mesh, positions, halo_size=0, comms=None):
|
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if comms is not None:
|
if sharding_info is not None:
|
||||||
# Add some padding and perfom hao exchange to retrieve
|
# Add some padding and perfom hao exchange to retrieve
|
||||||
# neighboring regions
|
# neighboring regions
|
||||||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
||||||
[halo_size, halo_size],
|
[halo_size, halo_size],
|
||||||
[0, 0]])
|
[0, 0]])
|
||||||
mesh = halo_reduce(mesh, halo_size, comms)
|
# mesh = halo_reduce(mesh, sharding_info)
|
||||||
|
import jaxdecomp
|
||||||
|
mesh = jaxdecomp.halo_exchange(mesh,
|
||||||
|
halo_extents=sharding_info.halo_extents,
|
||||||
|
halo_periods=(True,True,True),
|
||||||
|
pdims=sharding_info.pdims,
|
||||||
|
global_shape=sharding_info.global_shape)
|
||||||
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
||||||
|
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
|
|
30
jaxpm/pm.py
30
jaxpm/pm.py
|
@ -10,32 +10,32 @@ from jaxpm.painting import cic_paint, cic_read
|
||||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||||
|
|
||||||
|
|
||||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, token=None, comms=None):
|
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
|
||||||
"""
|
"""
|
||||||
Computes gravitational forces on particles using a PM scheme
|
Computes gravitational forces on particles using a PM scheme
|
||||||
"""
|
"""
|
||||||
if delta_k is None:
|
if delta_k is None:
|
||||||
delta = cic_paint(zeros(mesh_shape, comms=comms),
|
delta = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||||
positions,
|
positions,
|
||||||
halo_size=halo_size, comms=comms)
|
halo_size=halo_size, sharding_info=sharding_info)
|
||||||
delta_k = fft3d(delta, comms=comms)
|
delta_k = fft3d(delta, sharding_info=sharding_info)
|
||||||
|
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
kvec = fftk(delta_k.shape, symmetric=False, comms=comms)
|
kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
|
||||||
forces_k = apply_gradient_laplace(delta_k, kvec)
|
forces_k = apply_gradient_laplace(delta_k, kvec)
|
||||||
|
|
||||||
# Interpolate forces at the position of particles
|
# Interpolate forces at the position of particles
|
||||||
return jnp.stack([cic_read(ifft3d(forces_k[..., i], comms=comms).real,
|
return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real,
|
||||||
positions, halo_size=halo_size, comms=comms)
|
positions, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
for i in range(3)], axis=-1)
|
for i in range(3)], axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
|
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
|
||||||
"""
|
"""
|
||||||
Computes first order LPT displacement
|
Computes first order LPT displacement
|
||||||
"""
|
"""
|
||||||
initial_force = pm_forces(
|
initial_force = pm_forces(
|
||||||
positions, delta_k=initial_conditions, halo_size=halo_size, comms=comms)
|
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
dx = growth_factor(cosmo, a) * initial_force
|
||||||
p = a**2 * growth_rate(cosmo, a) * \
|
p = a**2 * growth_rate(cosmo, a) * \
|
||||||
|
@ -45,21 +45,21 @@ def lpt(cosmo, positions, initial_conditions, a, halo_size=0, comms=None):
|
||||||
return dx, p, f
|
return dx, p, f
|
||||||
|
|
||||||
|
|
||||||
def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
|
def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
|
||||||
"""
|
"""
|
||||||
Generate initial conditions in Fourier space.
|
Generate initial conditions in Fourier space.
|
||||||
"""
|
"""
|
||||||
# Sample normal field
|
# Sample normal field
|
||||||
field = normal(key, mesh_shape, comms=comms)
|
field = normal(key, mesh_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
# Transform to Fourier space
|
# Transform to Fourier space
|
||||||
kfield = fft3d(field, comms=comms)
|
kfield = fft3d(field, sharding_info=sharding_info)
|
||||||
|
|
||||||
# Rescaling k to physical units
|
# Rescaling k to physical units
|
||||||
kvec = [k / box_size[i] * mesh_shape[i]
|
kvec = [k / box_size[i] * mesh_shape[i]
|
||||||
for i, k in enumerate(fftk(kfield.shape,
|
for i, k in enumerate(fftk(kfield.shape,
|
||||||
symmetric=False,
|
symmetric=False,
|
||||||
comms=comms))]
|
sharding_info=sharding_info))]
|
||||||
|
|
||||||
# Evaluating linear matter powerspectrum
|
# Evaluating linear matter powerspectrum
|
||||||
k = jnp.logspace(-4, 2, 256)
|
k = jnp.logspace(-4, 2, 256)
|
||||||
|
@ -77,7 +77,7 @@ def linear_field(cosmo, mesh_shape, box_size, key, comms=None):
|
||||||
return kfield
|
return kfield
|
||||||
|
|
||||||
|
|
||||||
def make_ode_fn(mesh_shape, halo_size=0, comms=None):
|
def make_ode_fn(mesh_shape, halo_size=0, sharding_info=None):
|
||||||
|
|
||||||
def nbody_ode(state, a, cosmo):
|
def nbody_ode(state, a, cosmo):
|
||||||
"""
|
"""
|
||||||
|
@ -86,7 +86,7 @@ def make_ode_fn(mesh_shape, halo_size=0, comms=None):
|
||||||
pos, vel = state
|
pos, vel = state
|
||||||
|
|
||||||
forces = pm_forces(pos, mesh_shape=mesh_shape,
|
forces = pm_forces(pos, mesh_shape=mesh_shape,
|
||||||
halo_size=halo_size, comms=comms) * 1.5 * cosmo.Omega_m
|
halo_size=halo_size, sharding_info=sharding_info) * 1.5 * cosmo.Omega_m
|
||||||
|
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
from dataclasses import fields
|
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
|
import os
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as onp
|
import numpy as onp
|
||||||
import mpi4jax
|
import jaxdecomp
|
||||||
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros
|
from jaxpm.ops import fft3d, ifft3d, normal, meshgrid3d, zeros, ShardingInfo
|
||||||
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
from jaxpm.pm import linear_field, lpt, make_ode_fn
|
||||||
from jaxpm.painting import cic_paint
|
from jaxpm.painting import cic_paint
|
||||||
from jax.experimental.ode import odeint
|
from jax.experimental.ode import odeint
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
|
import time
|
||||||
|
|
||||||
### Setting up a whole bunch of things #######
|
### Setting up a whole bunch of things #######
|
||||||
# Create communicators
|
# Create communicators
|
||||||
|
@ -17,10 +17,12 @@ world = MPI.COMM_WORLD
|
||||||
rank = world.Get_rank()
|
rank = world.Get_rank()
|
||||||
size = world.Get_size()
|
size = world.Get_size()
|
||||||
|
|
||||||
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
|
# Here we assume clients are on the same node, so we restrict which device
|
||||||
periods=[True, True])
|
# they can use based on their rank
|
||||||
comms = [cart_comm.Sub([True, False]),
|
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
|
||||||
cart_comm.Sub([False, True])]
|
|
||||||
|
|
||||||
|
jaxdecomp.init()
|
||||||
|
|
||||||
# Setup random keys
|
# Setup random keys
|
||||||
master_key = jax.random.PRNGKey(42)
|
master_key = jax.random.PRNGKey(42)
|
||||||
|
@ -29,50 +31,77 @@ key = jax.random.split(master_key, size)[rank]
|
||||||
|
|
||||||
# Size and parameters of the simulation volume
|
# Size and parameters of the simulation volume
|
||||||
N = 256
|
N = 256
|
||||||
mesh_shape = [N, N, N]
|
mesh_shape = (N, N, N)
|
||||||
box_size = [205, 205, 205] # Mpc/h
|
box_size = [500, 500, 500] # Mpc/h
|
||||||
|
halo_size = 32
|
||||||
|
sharding_info = ShardingInfo(global_shape=mesh_shape,
|
||||||
|
pdims=(1,2),
|
||||||
|
halo_extents=(halo_size, halo_size, 0),
|
||||||
|
rank=rank)
|
||||||
cosmo = jc.Planck15()
|
cosmo = jc.Planck15()
|
||||||
halo_size = 16
|
|
||||||
a = 0.1
|
a = 0.1
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def run_sim(cosmo, key):
|
def run_sim(cosmo, key):
|
||||||
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||||
comms=comms)
|
sharding_info=sharding_info)
|
||||||
init_field = ifft3d(initial_conditions, comms=comms).real
|
|
||||||
|
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||||
|
|
||||||
# Initialize particles
|
# Initialize particles
|
||||||
pos = meshgrid3d(mesh_shape, comms=comms)
|
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
# Initial displacement by LPT
|
# Initial displacement by LPT
|
||||||
cosmo = jc.Planck15()
|
cosmo = jc.Planck15()
|
||||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, comms=comms)
|
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
# And now, we run an actual nbody
|
# And now, we run an actual nbody
|
||||||
res = odeint(make_ode_fn(mesh_shape, halo_size, comms),
|
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
|
||||||
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||||
rtol=1e-5, atol=1e-5)
|
rtol=1e-3, atol=1e-3)
|
||||||
|
|
||||||
# Painting on a new mesh
|
# Painting on a new mesh
|
||||||
field = cic_paint(zeros(mesh_shape, comms=comms),
|
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||||
res[0][-1], halo_size, comms=comms)
|
res[0][-1], halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
|
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||||
|
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||||
return init_field, field
|
return init_field, field
|
||||||
|
|
||||||
|
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||||
|
# sharding_info=sharding_info)
|
||||||
|
|
||||||
# Recover the real space initial conditions
|
# init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||||
|
|
||||||
|
# print("hello", init_field.shape)
|
||||||
|
|
||||||
|
# cosmo = jc.Planck15()
|
||||||
|
# pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||||
|
# dx, p, f = lpt(cosmo, pos, initial_conditions, a, sharding_info=sharding_info)
|
||||||
|
|
||||||
|
# #dx = 3*jax.random.normal(key=key, shape=[1048576, 3])
|
||||||
|
# # Initialize particles
|
||||||
|
# # pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
|
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
||||||
|
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
|
# # Recover the real space initial conditions
|
||||||
init_field, field = run_sim(cosmo, key)
|
init_field, field = run_sim(cosmo, key)
|
||||||
|
|
||||||
# Testing that the result is actually looking like what we expect
|
# import jaxdecomp
|
||||||
total_array, token = mpi4jax.allgather(field, comm=comms[0])
|
# field = jaxdecomp.halo_exchange(field,
|
||||||
total_array = total_array.reshape([N, N//2, N])
|
# halo_extents=sharding_info.halo_extents,
|
||||||
total_array, token = mpi4jax.allgather(
|
# halo_periods=(True,True,True),
|
||||||
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
|
# pdims=sharding_info.pdims,
|
||||||
total_array = total_array.reshape([N, N, N])
|
# global_shape=sharding_info.global_shape)
|
||||||
total_array = total_array.transpose([1, 0, 2])
|
|
||||||
|
|
||||||
if rank == 0:
|
# time1 = time.time()
|
||||||
onp.save('simulation.npy', total_array)
|
# init_field, field = run_sim(cosmo, key)
|
||||||
|
# init_field.block_until_ready()
|
||||||
|
# time2 = time.time()
|
||||||
|
|
||||||
print('Done !')
|
# if rank == 0:
|
||||||
|
onp.save('simulation_%d.npy'%rank, field)
|
||||||
|
|
||||||
|
# print('Done in', time2-time1)
|
||||||
|
|
Loading…
Add table
Reference in a new issue