This commit is contained in:
Wassim KABALAN 2024-04-19 10:32:38 +02:00
parent 055ceedb7e
commit 179030377b
4 changed files with 63 additions and 38 deletions

View file

@ -5,7 +5,10 @@ import jaxdecomp
from dataclasses import dataclass
from typing import Tuple
from functools import partial
from jax import jit
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
@dataclass
class ShardingInfo:
"""Class for keeping track of the distribution strategy"""
@ -31,22 +34,41 @@ def ifft3d(arr, sharding_info=None):
arr = jaxdecomp.pifft3d(arr)
return arr
def halo_reduce(arr, sharding_info=None):
if sharding_info is None:
def halo_reduce(arr, halo_size , gpu_mesh):
with gpu_mesh:
arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True))
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def apply_correction_x(arr):
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
return arr
halo_size = sharding_info.halo_extents[0]
global_shape = sharding_info.global_shape
arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True))
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def apply_correction_y(arr):
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
return arr
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
def un_pad(arr):
return arr[halo_size:-halo_size, halo_size:-halo_size]
# Apply correction along x
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
arr = apply_correction_x(arr)
# Apply correction along y
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
arr = apply_correction_y(arr)
arr = un_pad(arr)
return arr
@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None):
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None):
def zeros(mesh , shape, sharding_info=None):
""" Initialize an array of given global shape
partitionned if need be accross dimensions.
"""
if sharding_info is None:
return jnp.zeros(shape)
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \
sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y'))
return gspmd_zeros
def normal(key, shape, sharding_info=None):

View file

@ -1,4 +1,5 @@
import jax
from jax import jit
import jax.numpy as jnp
import jax.lax as lax
@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
print(f" positions {positions.shape}")
if sharding_info is not None:
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
# Add some padding for the halo exchange
with gpu_mesh:
nbody_mesh = sharded_pad(nbody_mesh)
positions = add_halo(positions , halo_size)
with gpu_mesh:
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
floor = jit(jnp.floor)(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
@jit
def compute_kernels(positions , neighboor_coords):
kernel = (1. - jnp.abs(positions - neighboor_coords))
return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2])
with gpu_mesh:
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jit(jnp.add)(floor , connection)
kernel = compute_kernels(positions , neighboor_coords)
neighboor_coords = jnp.mod(neighboor_coords.reshape(
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
if sharding_info == None:
return nbody_mesh
else:
with gpu_mesh :
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh)
return nbody_mesh

View file

@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
forces.append(force)
print(f"Shape {ifft_forces.shape}")
return jnp.stack(forces)
return jnp.stack(forces , axis=-1)
@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a)
print(f"Shape initial {initial_conditions.shape}")
@jax.jit
def compute_dx(cosmo , i_force):
@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
p = compute_p(cosmo , dx)
f = compute_f(cosmo , initial_force)
return dx, p, f

View file

@ -1,6 +1,7 @@
from mpi4py import MPI
import os
import jax
from jax import jit
import jax.numpy as jnp
import numpy as onp
import jaxdecomp
@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
def ifft3d_c2r(initial_conditions):
return ifft3d(initial_conditions, sharding_info=sharding_info).real
@jax.jit
def compute_displacement(p , dx):
return p + dx
def run_sim(mesh , initial_conditions, cosmo, key):
with mesh:
@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key):
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
# rtol=1e-3, atol=1e-3)
## Painting on a new mesh
print(f"shape of p {p.shape}")
print(f"shape of dx {dx.shape}")
with mesh:
displacement = compute_displacement(p , dx)
displacement = jit(jnp.add)(p , dx)
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info)
field = cic_paint(mesh , empty_field,
displacement, halo_size, sharding_info=sharding_info)
@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key):
# pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions
run_sim(mesh , initial_conditions,cosmo, key)
init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
# import jaxdecomp
@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key)
# init_field.block_until_ready()
# time2 = time.time()
# if rank == 0:
#onp.save('simulation_%d.npy'%rank, field)
onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16))
onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16))
# print('Done in', time2-time1)