mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 19:50:55 +00:00
temp commit
This commit is contained in:
parent
6ca4c9191e
commit
055ceedb7e
5 changed files with 220 additions and 110 deletions
|
@ -17,13 +17,13 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
|
||||||
# ix = sharding_info[0].Get_rank()
|
# ix = sharding_info[0].Get_rank()
|
||||||
# ny = sharding_info[1].Get_size()
|
# ny = sharding_info[1].Get_size()
|
||||||
# iy = sharding_info[1].Get_rank()
|
# iy = sharding_info[1].Get_rank()
|
||||||
ix = sharding_info.rank
|
ix = sharding_info.rank % nx
|
||||||
iy = 0
|
iy = 0
|
||||||
shape = sharding_info.global_shape
|
shape = sharding_info.global_shape
|
||||||
|
|
||||||
for d in range(len(shape)):
|
for d in range(len(shape)):
|
||||||
kd = np.fft.fftfreq(shape[d])
|
kd = jnp.fft.fftfreq(shape[d])
|
||||||
kd *= 2 * np.pi
|
kd *= 2 * jnp.pi
|
||||||
|
|
||||||
if symmetric and d == len(shape) - 1:
|
if symmetric and d == len(shape) - 1:
|
||||||
kd = kd[:shape[d] // 2 + 1]
|
kd = kd[:shape[d] // 2 + 1]
|
||||||
|
@ -38,12 +38,8 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
|
||||||
return k
|
return k
|
||||||
|
|
||||||
|
|
||||||
@partial(xmap,
|
@jax.jit
|
||||||
in_axes=[['x', 'y', ...],
|
def apply_gradient_laplace(kfield, kx, ky, kz):
|
||||||
[['x'], ['y'], [...]]],
|
|
||||||
out_axes=['x', 'y', ...])
|
|
||||||
def apply_gradient_laplace(kfield, kvec):
|
|
||||||
kx, ky, kz = kvec
|
|
||||||
kk = (kx**2 + ky**2 + kz**2)
|
kk = (kx**2 + ky**2 + kz**2)
|
||||||
kernel = jnp.where(kk == 0, 1., 1./kk)
|
kernel = jnp.where(kk == 0, 1., 1./kk)
|
||||||
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),
|
||||||
|
|
19
jaxpm/ops.py
19
jaxpm/ops.py
|
@ -1,10 +1,10 @@
|
||||||
# Module for custom ops, typically mpi4jax
|
# Module for custom ops, typically mpi4jax
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import mpi4jax
|
|
||||||
import jaxdecomp
|
import jaxdecomp
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ShardingInfo:
|
class ShardingInfo:
|
||||||
|
@ -21,22 +21,16 @@ def fft3d(arr, sharding_info=None):
|
||||||
if sharding_info is None:
|
if sharding_info is None:
|
||||||
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
|
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
|
||||||
else:
|
else:
|
||||||
arr = jaxdecomp.pfft3d(arr,
|
arr = jaxdecomp.pfft3d(arr)
|
||||||
pdims=sharding_info.pdims,
|
|
||||||
global_shape=sharding_info.global_shape)
|
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def ifft3d(arr, sharding_info=None):
|
def ifft3d(arr, sharding_info=None):
|
||||||
if sharding_info is None:
|
if sharding_info is None:
|
||||||
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
|
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
|
||||||
else:
|
else:
|
||||||
arr = jaxdecomp.pifft3d(arr,
|
arr = jaxdecomp.pifft3d(arr)
|
||||||
pdims=sharding_info.pdims,
|
|
||||||
global_shape=sharding_info.global_shape)
|
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def halo_reduce(arr, sharding_info=None):
|
def halo_reduce(arr, sharding_info=None):
|
||||||
if sharding_info is None:
|
if sharding_info is None:
|
||||||
return arr
|
return arr
|
||||||
|
@ -44,11 +38,7 @@ def halo_reduce(arr, sharding_info=None):
|
||||||
global_shape = sharding_info.global_shape
|
global_shape = sharding_info.global_shape
|
||||||
arr = jaxdecomp.halo_exchange(arr,
|
arr = jaxdecomp.halo_exchange(arr,
|
||||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||||
halo_periods=(True,True,True),
|
halo_periods=(True,True,True))
|
||||||
pdims=sharding_info.pdims,
|
|
||||||
global_shape=(global_shape[0]+2*halo_size,
|
|
||||||
global_shape[1]+halo_size,
|
|
||||||
global_shape[2]))
|
|
||||||
|
|
||||||
# Apply correction along x
|
# Apply correction along x
|
||||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||||
|
@ -70,7 +60,6 @@ def meshgrid3d(shape, sharding_info=None):
|
||||||
|
|
||||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||||
|
|
||||||
|
|
||||||
def zeros(shape, sharding_info=None):
|
def zeros(shape, sharding_info=None):
|
||||||
""" Initialize an array of given global shape
|
""" Initialize an array of given global shape
|
||||||
partitioned if need be across dimensions.
|
partitioned if need be across dimensions.
|
||||||
|
|
|
@ -4,76 +4,114 @@ import jax.lax as lax
|
||||||
|
|
||||||
from jaxpm.ops import halo_reduce
|
from jaxpm.ops import halo_reduce
|
||||||
from jaxpm.kernels import fftk, cic_compensation
|
from jaxpm.kernels import fftk, cic_compensation
|
||||||
|
import jaxdecomp
|
||||||
|
from functools import partial
|
||||||
|
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
|
||||||
|
|
||||||
def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
|
@partial(jax.jit, static_argnums=(1,))
def add_halo(positions, halo_size):
    """Shift positions by the halo padding along the first two coordinates.

    positions: array whose last axis holds the 3 coordinates of each
        particle (it is broadcast against an offset of shape [1, 3])
    halo_size: Python int; marked static for jit, so each distinct value
        is compiled separately

    Returns a new array with halo_size added to coordinates 0 and 1;
    coordinate 2 is left unchanged.
    """
    # Build the offset once and return a fresh array instead of
    # re-binding the argument with an augmented assignment.
    offset = jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
    return positions + offset
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
|
print(f" positions {positions.shape}")
|
||||||
if sharding_info is not None:
|
if sharding_info is not None:
|
||||||
# Add some padding for the halo exchange
|
|
||||||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
|
||||||
[halo_size, halo_size],
|
|
||||||
[0, 0]])
|
|
||||||
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
|
||||||
|
|
||||||
|
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
|
||||||
|
out_specs=P('z', 'y'))
|
||||||
|
def sharded_pad(arr):
|
||||||
|
padded = jnp.pad(arr,pad_width=((halo_size, halo_size), (halo_size, halo_size), (0, 0)))
|
||||||
|
return padded
|
||||||
|
|
||||||
|
# Add some padding for the halo exchange
|
||||||
|
with gpu_mesh:
|
||||||
|
nbody_mesh = sharded_pad(nbody_mesh)
|
||||||
|
positions = add_halo(positions , halo_size)
|
||||||
|
|
||||||
|
with gpu_mesh:
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
|
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
|
with gpu_mesh:
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
||||||
[-1, 8, 3]).astype('int32'), jnp.array(mesh.shape))
|
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
|
||||||
|
|
||||||
dnums = jax.lax.ScatterDimensionNumbers(
|
dnums = jax.lax.ScatterDimensionNumbers(
|
||||||
update_window_dims=(),
|
update_window_dims=(),
|
||||||
inserted_window_dims=(0, 1, 2),
|
inserted_window_dims=(0, 1, 2),
|
||||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||||
mesh = lax.scatter_add(mesh,
|
|
||||||
|
with gpu_mesh:
|
||||||
|
nbody_mesh = lax.scatter_add(nbody_mesh,
|
||||||
neighboor_coords,
|
neighboor_coords,
|
||||||
kernel.reshape([-1, 8]),
|
kernel.reshape([-1, 8]),
|
||||||
dnums)
|
dnums)
|
||||||
|
|
||||||
if sharding_info == None:
|
if sharding_info == None:
|
||||||
return mesh
|
return nbody_mesh
|
||||||
else:
|
else:
|
||||||
mesh = halo_reduce(mesh, sharding_info)
|
with gpu_mesh :
|
||||||
return mesh[halo_size:-halo_size, halo_size:-halo_size]
|
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
|
||||||
|
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
return nbody_mesh
|
||||||
|
|
||||||
|
@jax.jit
def reduce_and_sum(mesh, neighboor_coords, kernel):
    """Gather mesh values at the given cells and sum them weighted by kernel.

    mesh: array indexed by three integer coordinates
    neighboor_coords: integer array whose last axis holds the three cell
        coordinates of each neighbour (as built by cic_read from
        `floor + connection`)
    kernel: weights with the shape of neighboor_coords minus its last axis

    Returns the weighted sum over the last (neighbour) axis.
    """
    # The coordinate axis has components 0, 1 and 2 only; the previous
    # code asked for component 3, which does not exist.
    gathered = mesh[neighboor_coords[..., 0],
                    neighboor_coords[..., 1],
                    neighboor_coords[..., 2]]
    return (gathered * kernel).sum(axis=-1)
|
||||||
|
|
||||||
|
|
||||||
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
|
|
||||||
|
def cic_read(gpu_mesh , mesh, positions, halo_size=0, sharding_info=None):
|
||||||
""" Paints positions onto mesh
|
""" Paints positions onto mesh
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
|
@partial(shard_map, mesh=gpu_mesh, in_specs=(P('z', 'y'),P()),
|
||||||
|
out_specs=P('z', 'y'))
|
||||||
|
def sharded_pad(arr , padding_width):
|
||||||
|
return jnp.pad(arr,pad_width=padding_width)
|
||||||
|
|
||||||
if sharding_info is not None:
|
if sharding_info is not None:
|
||||||
# Add some padding and perform halo exchange to retrieve
|
# Add some padding and perform halo exchange to retrieve
|
||||||
# neighboring regions
|
# neighboring regions
|
||||||
mesh = jnp.pad(mesh, [[halo_size, halo_size],
|
|
||||||
[halo_size, halo_size],
|
|
||||||
[0, 0]])
|
|
||||||
# mesh = halo_reduce(mesh, sharding_info)
|
# mesh = halo_reduce(mesh, sharding_info)
|
||||||
import jaxdecomp
|
with gpu_mesh:
|
||||||
|
padding_width = jnp.array([(halo_size, halo_size), (halo_size, halo_size), (0, 0)])
|
||||||
|
#mesh = sharded_pad(mesh,padding_width)
|
||||||
mesh = jaxdecomp.halo_exchange(mesh,
|
mesh = jaxdecomp.halo_exchange(mesh,
|
||||||
halo_extents=sharding_info.halo_extents,
|
halo_extents=sharding_info.halo_extents,
|
||||||
halo_periods=(True,True,True),
|
halo_periods=(True,True,True))
|
||||||
pdims=sharding_info.pdims,
|
positions = add_halo(positions , halo_size)
|
||||||
global_shape=sharding_info.global_shape)
|
|
||||||
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
|
|
||||||
|
|
||||||
|
with gpu_mesh:
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jnp.floor(positions)
|
||||||
|
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
|
with gpu_mesh:
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = floor + connection
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||||
|
@ -81,9 +119,9 @@ def cic_read(mesh, positions, halo_size=0, sharding_info=None):
|
||||||
neighboor_coords = jnp.mod(
|
neighboor_coords = jnp.mod(
|
||||||
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
|
||||||
|
|
||||||
return (mesh[neighboor_coords[..., 0],
|
reduced = reduce_and_sum(mesh,neighboor_coords,kernel)
|
||||||
neighboor_coords[..., 1],
|
|
||||||
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
|
return reduced
|
||||||
|
|
||||||
|
|
||||||
def cic_paint_2d(mesh, positions, weight):
|
def cic_paint_2d(mesh, positions, weight):
|
||||||
|
|
95
jaxpm/pm.py
95
jaxpm/pm.py
|
@ -8,9 +8,14 @@ from jaxpm.ops import fft3d, ifft3d, zeros, normal
|
||||||
from jaxpm.kernels import fftk, apply_gradient_laplace
|
from jaxpm.kernels import fftk, apply_gradient_laplace
|
||||||
from jaxpm.painting import cic_paint, cic_read
|
from jaxpm.painting import cic_paint, cic_read
|
||||||
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
from jaxpm.growth import growth_factor, growth_rate, dGfa
|
||||||
|
from jax.experimental import mesh_utils, multihost_utils
|
||||||
|
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
|
||||||
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
|
|
||||||
|
def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
|
||||||
"""
|
"""
|
||||||
Computes gravitational forces on particles using a PM scheme
|
Computes gravitational forces on particles using a PM scheme
|
||||||
"""
|
"""
|
||||||
|
@ -22,37 +27,87 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_in
|
||||||
|
|
||||||
# Computes gravitational forces
|
# Computes gravitational forces
|
||||||
kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
|
kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
|
||||||
forces_k = apply_gradient_laplace(delta_k, kvec)
|
|
||||||
|
|
||||||
|
local_kx = kvec[0]
|
||||||
|
local_ky = kvec[1]
|
||||||
|
replicated_kz = kvec[2]
|
||||||
|
|
||||||
|
gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z'))
|
||||||
|
gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y'))
|
||||||
|
|
||||||
|
@partial(jax.jit,static_argnums=(1))
|
||||||
|
def ifft3d_c2r(forces_k , i):
|
||||||
|
return ifft3d(forces_k[..., i], sharding_info=sharding_info).real
|
||||||
|
|
||||||
|
forces = []
|
||||||
|
with mesh:
|
||||||
|
forces_k = apply_gradient_laplace(delta_k, gspmd_kx , gspmd_ky , replicated_kz)
|
||||||
# Interpolate forces at the position of particles
|
# Interpolate forces at the position of particles
|
||||||
return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real,
|
|
||||||
positions, halo_size=halo_size, sharding_info=sharding_info)
|
for i in range(3):
|
||||||
for i in range(3)], axis=-1)
|
with mesh:
|
||||||
|
ifft_forces = ifft3d_c2r(forces_k , i)
|
||||||
|
|
||||||
|
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
|
forces.append(force)
|
||||||
|
print(f"Shape {ifft_forces.shape}")
|
||||||
|
|
||||||
|
return jnp.stack(forces)
|
||||||
|
|
||||||
|
|
||||||
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
|
|
||||||
|
def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
|
||||||
"""
|
"""
|
||||||
Computes first order LPT displacement
|
Computes first order LPT displacement
|
||||||
"""
|
"""
|
||||||
initial_force = pm_forces(
|
initial_force = pm_forces(mesh,
|
||||||
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
dx = growth_factor(cosmo, a) * initial_force
|
|
||||||
p = a**2 * growth_rate(cosmo, a) * \
|
print(f"Shape initial {initial_conditions.shape}")
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def compute_dx(cosmo , i_force):
|
||||||
|
return growth_factor(cosmo, a) * i_force
|
||||||
|
|
||||||
|
@jax.jit
|
||||||
|
def compute_p(cosmo , dx):
|
||||||
|
return a**2 * growth_rate(cosmo, a) * \
|
||||||
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
|
||||||
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
|
|
||||||
|
@jax.jit
|
||||||
|
def compute_f(cosmo , initial_force):
|
||||||
|
return a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
|
||||||
dGfa(cosmo, a) * initial_force
|
dGfa(cosmo, a) * initial_force
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
dx = compute_dx(cosmo , initial_force)
|
||||||
|
p = compute_p(cosmo , dx)
|
||||||
|
f = compute_f(cosmo , initial_force)
|
||||||
|
|
||||||
return dx, p, f
|
return dx, p, f
|
||||||
|
|
||||||
|
|
||||||
def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
|
@jax.jit
def interpolate(kfield, kx, ky, kz, k, pk):
    """Scale kfield by sqrt(pk) interpolated at the wavenumber norm.

    The norm sqrt(kx**2 + ky**2 + kz**2) is looked up in the table
    (k, sqrt(pk)) with jc.scipy.interpolate.interp, and the result
    multiplies kfield.
    """
    k_norm = jnp.sqrt(kx**2 + ky**2 + kz**2)
    amplitude = jc.scipy.interpolate.interp(k_norm, k, jnp.sqrt(pk))
    return kfield * amplitude
|
||||||
|
|
||||||
|
|
||||||
|
def linear_field(cosmo, mesh, mesh_shape, box_size, key, sharding_info=None):
|
||||||
"""
|
"""
|
||||||
Generate initial conditions in Fourier space.
|
Generate initial conditions in Fourier space.
|
||||||
"""
|
"""
|
||||||
# Sample normal field
|
# Sample normal field
|
||||||
field = normal(key, mesh_shape, sharding_info=sharding_info)
|
pdims = sharding_info.pdims
|
||||||
|
slice_shape = (mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0],mesh_shape[2])
|
||||||
|
|
||||||
|
slice_field = normal(key, slice_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
|
field = multihost_utils.host_local_array_to_global_array(
|
||||||
|
slice_field, mesh, P('z', 'y'))
|
||||||
|
|
||||||
# Transform to Fourier space
|
# Transform to Fourier space
|
||||||
|
with mesh :
|
||||||
kfield = fft3d(field, sharding_info=sharding_info)
|
kfield = fft3d(field, sharding_info=sharding_info)
|
||||||
|
|
||||||
# Rescaling k to physical units
|
# Rescaling k to physical units
|
||||||
|
@ -68,11 +123,17 @@ def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
|
||||||
) / (box_size[0] * box_size[1] * box_size[2])
|
) / (box_size[0] * box_size[1] * box_size[2])
|
||||||
|
|
||||||
# Multiplying the field by the proper power spectrum
|
# Multiplying the field by the proper power spectrum
|
||||||
kfield = xmap(lambda kfield, kx, ky, kz:
|
|
||||||
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2),
|
local_kx = kvec[0]
|
||||||
k, jnp.sqrt(pk)),
|
local_ky = kvec[1]
|
||||||
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]),
|
replicated_kz = kvec[2]
|
||||||
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
|
|
||||||
|
gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z'))
|
||||||
|
gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y'))
|
||||||
|
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
kfield = interpolate(kfield,gspmd_kx, gspmd_ky, replicated_kz ,k, pk)
|
||||||
|
|
||||||
return kfield
|
return kfield
|
||||||
|
|
||||||
|
|
|
@ -10,19 +10,20 @@ from jaxpm.painting import cic_paint
|
||||||
from jax.experimental.ode import odeint
|
from jax.experimental.ode import odeint
|
||||||
import jax_cosmo as jc
|
import jax_cosmo as jc
|
||||||
import time
|
import time
|
||||||
|
from jax.experimental import mesh_utils, multihost_utils
|
||||||
|
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||||
|
from functools import partial
|
||||||
### Setting up a whole bunch of things #######
|
### Setting up a whole bunch of things #######
|
||||||
# Create communicators
|
# Create communicators
|
||||||
world = MPI.COMM_WORLD
|
world = MPI.COMM_WORLD
|
||||||
rank = world.Get_rank()
|
rank = world.Get_rank()
|
||||||
size = world.Get_size()
|
size = world.Get_size()
|
||||||
|
|
||||||
|
jax.config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
# Here we assume clients are on the same node, so we restrict which device
|
# Here we assume clients are on the same node, so we restrict which device
|
||||||
# they can use based on their rank
|
# they can use based on their rank
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
|
jax.distributed.initialize()
|
||||||
|
|
||||||
|
|
||||||
jaxdecomp.init()
|
|
||||||
|
|
||||||
# Setup random keys
|
# Setup random keys
|
||||||
master_key = jax.random.PRNGKey(42)
|
master_key = jax.random.PRNGKey(42)
|
||||||
|
@ -35,37 +36,57 @@ mesh_shape = (N, N, N)
|
||||||
box_size = [500, 500, 500] # Mpc/h
|
box_size = [500, 500, 500] # Mpc/h
|
||||||
halo_size = 32
|
halo_size = 32
|
||||||
sharding_info = ShardingInfo(global_shape=mesh_shape,
|
sharding_info = ShardingInfo(global_shape=mesh_shape,
|
||||||
pdims=(1,2),
|
pdims=(2,2),
|
||||||
halo_extents=(halo_size, halo_size, 0),
|
halo_extents=(halo_size, halo_size, 0),
|
||||||
rank=rank)
|
rank=rank)
|
||||||
cosmo = jc.Planck15()
|
cosmo = jc.Planck15()
|
||||||
a = 0.1
|
a = 0.1
|
||||||
|
|
||||||
|
|
||||||
@jax.jit
|
devices = mesh_utils.create_device_mesh(sharding_info.pdims[::-1])
|
||||||
def run_sim(cosmo, key):
|
mesh = Mesh(devices, axis_names=('z', 'y'))
|
||||||
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
|
||||||
|
initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
|
||||||
sharding_info=sharding_info)
|
sharding_info=sharding_info)
|
||||||
|
|
||||||
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
|
@jax.jit
def ifft3d_c2r(initial_conditions):
    """Apply ifft3d with the script's sharding_info and keep the real part."""
    real_space = ifft3d(initial_conditions, sharding_info=sharding_info)
    return real_space.real
|
||||||
|
|
||||||
|
@jax.jit
def compute_displacement(p, dx):
    """Return the element-wise sum of p and dx (jit-compiled)."""
    total = jnp.add(p, dx)
    return total
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def run_sim(mesh , initial_conditions, cosmo, key):
|
||||||
|
|
||||||
|
with mesh:
|
||||||
|
init_field = ifft3d_c2r(initial_conditions)
|
||||||
|
|
||||||
# Initialize particles
|
# Initialize particles
|
||||||
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
# Initial displacement by LPT
|
# Initial displacement by LPT
|
||||||
cosmo = jc.Planck15()
|
cosmo = jc.Planck15()
|
||||||
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
|
|
||||||
|
dx, p, f = lpt(mesh , cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
# And now, we run an actual nbody
|
# And now, we run an actual nbody
|
||||||
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
|
#res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
|
||||||
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||||
rtol=1e-3, atol=1e-3)
|
# rtol=1e-3, atol=1e-3)
|
||||||
# Painting on a new mesh
|
## Painting on a new mesh
|
||||||
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
print(f"shape of p {p.shape}")
|
||||||
res[0][-1], halo_size, sharding_info=sharding_info)
|
print(f"shape of dx {dx.shape}")
|
||||||
|
with mesh:
|
||||||
|
displacement = compute_displacement(p , dx)
|
||||||
|
|
||||||
|
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
|
field = cic_paint(mesh , empty_field,
|
||||||
|
displacement, halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
|
|
||||||
# pos+dx, halo_size, sharding_info=sharding_info)
|
|
||||||
return init_field, field
|
return init_field, field
|
||||||
|
|
||||||
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
|
||||||
|
@ -87,7 +108,8 @@ def run_sim(cosmo, key):
|
||||||
# pos+dx, halo_size, sharding_info=sharding_info)
|
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
# # Recover the real space initial conditions
|
# # Recover the real space initial conditions
|
||||||
init_field, field = run_sim(cosmo, key)
|
run_sim(mesh , initial_conditions,cosmo, key)
|
||||||
|
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||||
|
|
||||||
# import jaxdecomp
|
# import jaxdecomp
|
||||||
# field = jaxdecomp.halo_exchange(field,
|
# field = jaxdecomp.halo_exchange(field,
|
||||||
|
@ -102,6 +124,10 @@ init_field, field = run_sim(cosmo, key)
|
||||||
# time2 = time.time()
|
# time2 = time.time()
|
||||||
|
|
||||||
# if rank == 0:
|
# if rank == 0:
|
||||||
onp.save('simulation_%d.npy'%rank, field)
|
#onp.save('simulation_%d.npy'%rank, field)
|
||||||
|
|
||||||
# print('Done in', time2-time1)
|
# print('Done in', time2-time1)
|
||||||
|
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
jaxdecomp.finalize()
|
Loading…
Add table
Reference in a new issue