temp commit

This commit is contained in:
Wassim KABALAN 2024-04-19 01:11:25 +02:00
parent 6ca4c9191e
commit 055ceedb7e
5 changed files with 220 additions and 110 deletions

View file

@ -17,13 +17,13 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
# ix = sharding_info[0].Get_rank()
# ny = sharding_info[1].Get_size()
# iy = sharding_info[1].Get_rank()
ix = sharding_info.rank
ix = sharding_info.rank % nx
iy = 0
shape = sharding_info.global_shape
for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d])
kd *= 2 * np.pi
kd = jnp.fft.fftfreq(shape[d])
kd *= 2 * jnp.pi
if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1]
@ -38,12 +38,8 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
return k
@partial(xmap,
in_axes=[['x', 'y', ...],
[['x'], ['y'], [...]]],
out_axes=['x', 'y', ...])
def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec
@jax.jit
def apply_gradient_laplace(kfield, kx, ky, kz):
kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),

View file

@ -1,10 +1,10 @@
# Module for custom ops, typically mpi4jax
import jax
import jax.numpy as jnp
import mpi4jax
import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
from functools import partial
@dataclass
class ShardingInfo:
@ -21,22 +21,16 @@ def fft3d(arr, sharding_info=None):
if sharding_info is None:
arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
else:
arr = jaxdecomp.pfft3d(arr,
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
arr = jaxdecomp.pfft3d(arr)
return arr
def ifft3d(arr, sharding_info=None):
if sharding_info is None:
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
else:
arr = jaxdecomp.pifft3d(arr,
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
arr = jaxdecomp.pifft3d(arr)
return arr
def halo_reduce(arr, sharding_info=None):
if sharding_info is None:
return arr
@ -44,11 +38,7 @@ def halo_reduce(arr, sharding_info=None):
global_shape = sharding_info.global_shape
arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True),
pdims=sharding_info.pdims,
global_shape=(global_shape[0]+2*halo_size,
global_shape[1]+halo_size,
global_shape[2]))
halo_periods=(True,True,True))
# Apply correction along x
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
@ -70,7 +60,6 @@ def meshgrid3d(shape, sharding_info=None):
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None):
""" Initialize an array of given global shape
partitionned if need be accross dimensions.

View file

@ -4,86 +4,124 @@ import jax.lax as lax
from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation
import jaxdecomp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
def cic_paint(mesh, positions, halo_size=0, sharding_info=None):
@partial(jax.jit,static_argnums=(1))
def add_halo(positions , halo_size):
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
return positions
def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
print(f" positions {positions.shape}")
if sharding_info is not None:
# Add some padding for the halo exchange
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
out_specs=P('z', 'y'))
def sharded_pad(arr):
padded = jnp.pad(arr,pad_width=((halo_size, halo_size), (halo_size, halo_size), (0, 0)))
return padded
# Add some padding for the halo exchange
with gpu_mesh:
nbody_mesh = sharded_pad(nbody_mesh)
positions = add_halo(positions , halo_size)
with gpu_mesh:
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
with gpu_mesh:
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.reshape(
[-1, 8, 3]).astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(neighboor_coords.reshape(
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords,
kernel.reshape([-1, 8]),
dnums)
with gpu_mesh:
nbody_mesh = lax.scatter_add(nbody_mesh,
neighboor_coords,
kernel.reshape([-1, 8]),
dnums)
if sharding_info == None:
return mesh
return nbody_mesh
else:
mesh = halo_reduce(mesh, sharding_info)
return mesh[halo_size:-halo_size, halo_size:-halo_size]
with gpu_mesh :
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
return nbody_mesh
@jax.jit
def reduce_and_sum(mesh,neighboor_coords,kernel):
return (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
def cic_read(gpu_mesh , mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
@partial(shard_map, mesh=gpu_mesh, in_specs=(P('z', 'y'),P()),
out_specs=P('z', 'y'))
def sharded_pad(arr , padding_width):
return jnp.pad(arr,pad_width=padding_width)
if sharding_info is not None:
# Add some padding and perfom hao exchange to retrieve
# neighboring regions
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
# mesh = halo_reduce(mesh, sharding_info)
import jaxdecomp
mesh = jaxdecomp.halo_exchange(mesh,
halo_extents=sharding_info.halo_extents,
halo_periods=(True,True,True),
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
with gpu_mesh:
padding_width = jnp.array([(halo_size, halo_size), (halo_size, halo_size), (0, 0)])
#mesh = sharded_pad(mesh,padding_width)
mesh = jaxdecomp.halo_exchange(mesh,
halo_extents=sharding_info.halo_extents,
halo_periods=(True,True,True))
positions = add_halo(positions , halo_size)
with gpu_mesh:
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
with gpu_mesh:
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
neighboor_coords = jnp.mod(
neighboor_coords.astype('int32'), jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
reduced = reduce_and_sum(mesh,neighboor_coords,kernel)
return reduced
def cic_paint_2d(mesh, positions, weight):

View file

@ -8,9 +8,14 @@ from jaxpm.ops import fft3d, ifft3d, zeros, normal
from jaxpm.kernels import fftk, apply_gradient_laplace
from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
from functools import partial
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
"""
Computes gravitational forces on particles using a PM scheme
"""
@ -22,38 +27,88 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_in
# Computes gravitational forces
kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
forces_k = apply_gradient_laplace(delta_k, kvec)
# Interpolate forces at the position of particles
return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real,
positions, halo_size=halo_size, sharding_info=sharding_info)
for i in range(3)], axis=-1)
local_kx = kvec[0]
local_ky = kvec[1]
replicated_kz = kvec[2]
gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z'))
gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y'))
@partial(jax.jit,static_argnums=(1))
def ifft3d_c2r(forces_k , i):
return ifft3d(forces_k[..., i], sharding_info=sharding_info).real
forces = []
with mesh:
forces_k = apply_gradient_laplace(delta_k, gspmd_kx , gspmd_ky , replicated_kz)
# Interpolate forces at the position of particles
for i in range(3):
with mesh:
ifft_forces = ifft3d_c2r(forces_k , i)
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
forces.append(force)
print(f"Shape {ifft_forces.shape}")
return jnp.stack(forces)
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
"""
Computes first order LPT displacement
"""
initial_force = pm_forces(
initial_force = pm_forces(mesh,
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * \
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
dGfa(cosmo, a) * initial_force
print(f"Shape initial {initial_conditions.shape}")
@jax.jit
def compute_dx(cosmo , i_force):
return growth_factor(cosmo, a) * i_force
@jax.jit
def compute_p(cosmo , dx):
return a**2 * growth_rate(cosmo, a) * \
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
@jax.jit
def compute_f(cosmo , initial_force):
return a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
dGfa(cosmo, a) * initial_force
with mesh:
dx = compute_dx(cosmo , initial_force)
p = compute_p(cosmo , dx)
f = compute_f(cosmo , initial_force)
return dx, p, f
def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
@jax.jit
def interpolate(kfield, kx, ky, kz , k , pk):
return kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2), k, jnp.sqrt(pk))
def linear_field(cosmo, mesh, mesh_shape, box_size, key, sharding_info=None):
"""
Generate initial conditions in Fourier space.
"""
# Sample normal field
field = normal(key, mesh_shape, sharding_info=sharding_info)
pdims = sharding_info.pdims
slice_shape = (mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0],mesh_shape[2])
slice_field = normal(key, slice_shape, sharding_info=sharding_info)
field = multihost_utils.host_local_array_to_global_array(
slice_field, mesh, P('z', 'y'))
# Transform to Fourier space
kfield = fft3d(field, sharding_info=sharding_info)
with mesh :
kfield = fft3d(field, sharding_info=sharding_info)
# Rescaling k to physical units
kvec = [k / box_size[i] * mesh_shape[i]
@ -68,11 +123,17 @@ def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
) / (box_size[0] * box_size[1] * box_size[2])
# Multipliyng the field by the proper power spectrum
kfield = xmap(lambda kfield, kx, ky, kz:
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2),
k, jnp.sqrt(pk)),
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]),
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
local_kx = kvec[0]
local_ky = kvec[1]
replicated_kz = kvec[2]
gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z'))
gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y'))
with mesh:
kfield = interpolate(kfield,gspmd_kx, gspmd_ky, replicated_kz ,k, pk)
return kfield

View file

@ -10,19 +10,20 @@ from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint
import jax_cosmo as jc
import time
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from functools import partial
### Setting up a whole bunch of things #######
# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()
jax.config.update("jax_enable_x64", True)
# Here we assume clients are on the same node, so we restrict which device
# they can use based on their rank
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1)
jaxdecomp.init()
jax.distributed.initialize()
# Setup random keys
master_key = jax.random.PRNGKey(42)
@ -35,37 +36,57 @@ mesh_shape = (N, N, N)
box_size = [500, 500, 500] # Mpc/h
halo_size = 32
sharding_info = ShardingInfo(global_shape=mesh_shape,
pdims=(1,2),
pdims=(2,2),
halo_extents=(halo_size, halo_size, 0),
rank=rank)
cosmo = jc.Planck15()
a = 0.1
devices = mesh_utils.create_device_mesh(sharding_info.pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
sharding_info=sharding_info)
@jax.jit
def run_sim(cosmo, key):
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
sharding_info=sharding_info)
def ifft3d_c2r(initial_conditions):
return ifft3d(initial_conditions, sharding_info=sharding_info).real
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real
@jax.jit
def compute_displacement(p , dx):
return p + dx
# Initialize particles
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
def run_sim(mesh , initial_conditions, cosmo, key):
with mesh:
init_field = ifft3d_c2r(initial_conditions)
# Initialize particles
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# Initial displacement by LPT
cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
dx, p, f = lpt(mesh , cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
# And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-3, atol=1e-3)
# Painting on a new mesh
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
res[0][-1], halo_size, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
#res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
# rtol=1e-3, atol=1e-3)
## Painting on a new mesh
print(f"shape of p {p.shape}")
print(f"shape of dx {dx.shape}")
with mesh:
displacement = compute_displacement(p , dx)
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
field = cic_paint(mesh , empty_field,
displacement, halo_size, sharding_info=sharding_info)
return init_field, field
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
@ -87,7 +108,8 @@ def run_sim(cosmo, key):
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
init_field, field = run_sim(cosmo, key)
run_sim(mesh , initial_conditions,cosmo, key)
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
# import jaxdecomp
# field = jaxdecomp.halo_exchange(field,
@ -102,6 +124,10 @@ init_field, field = run_sim(cosmo, key)
# time2 = time.time()
# if rank == 0:
onp.save('simulation_%d.npy'%rank, field)
#onp.save('simulation_%d.npy'%rank, field)
# print('Done in', time2-time1)
print("Done")
jaxdecomp.finalize()