temp commit

This commit is contained in:
Wassim KABALAN 2024-04-19 01:11:25 +02:00
parent 6ca4c9191e
commit 055ceedb7e
5 changed files with 220 additions and 110 deletions

View file

@ -17,13 +17,13 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
# ix = sharding_info[0].Get_rank() # ix = sharding_info[0].Get_rank()
# ny = sharding_info[1].Get_size() # ny = sharding_info[1].Get_size()
# iy = sharding_info[1].Get_rank() # iy = sharding_info[1].Get_rank()
ix = sharding_info.rank ix = sharding_info.rank % nx
iy = 0 iy = 0
shape = sharding_info.global_shape shape = sharding_info.global_shape
for d in range(len(shape)): for d in range(len(shape)):
kd = np.fft.fftfreq(shape[d]) kd = jnp.fft.fftfreq(shape[d])
kd *= 2 * np.pi kd *= 2 * jnp.pi
if symmetric and d == len(shape) - 1: if symmetric and d == len(shape) - 1:
kd = kd[:shape[d] // 2 + 1] kd = kd[:shape[d] // 2 + 1]
@ -38,12 +38,8 @@ def fftk(shape, symmetric=False, dtype=np.float32, sharding_info=None):
return k return k
@partial(xmap, @jax.jit
in_axes=[['x', 'y', ...], def apply_gradient_laplace(kfield, kx, ky, kz):
[['x'], ['y'], [...]]],
out_axes=['x', 'y', ...])
def apply_gradient_laplace(kfield, kvec):
kx, ky, kz = kvec
kk = (kx**2 + ky**2 + kz**2) kk = (kx**2 + ky**2 + kz**2)
kernel = jnp.where(kk == 0, 1., 1./kk) kernel = jnp.where(kk == 0, 1., 1./kk)
return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)), return jnp.stack([kfield * kernel * 1j * 1 / 6.0 * (8 * jnp.sin(kz) - jnp.sin(2 * kz)),

View file

@ -1,10 +1,10 @@
# Module for custom ops, typically mpi4jax # Module for custom ops, typically mpi4jax
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import mpi4jax
import jaxdecomp import jaxdecomp
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple from typing import Tuple
from functools import partial
@dataclass @dataclass
class ShardingInfo: class ShardingInfo:
@ -21,22 +21,16 @@ def fft3d(arr, sharding_info=None):
if sharding_info is None: if sharding_info is None:
arr = jnp.fft.fftn(arr).transpose([1, 2, 0]) arr = jnp.fft.fftn(arr).transpose([1, 2, 0])
else: else:
arr = jaxdecomp.pfft3d(arr, arr = jaxdecomp.pfft3d(arr)
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
return arr return arr
def ifft3d(arr, sharding_info=None): def ifft3d(arr, sharding_info=None):
if sharding_info is None: if sharding_info is None:
arr = jnp.fft.ifftn(arr.transpose([2, 0, 1])) arr = jnp.fft.ifftn(arr.transpose([2, 0, 1]))
else: else:
arr = jaxdecomp.pifft3d(arr, arr = jaxdecomp.pifft3d(arr)
pdims=sharding_info.pdims,
global_shape=sharding_info.global_shape)
return arr return arr
def halo_reduce(arr, sharding_info=None): def halo_reduce(arr, sharding_info=None):
if sharding_info is None: if sharding_info is None:
return arr return arr
@ -44,11 +38,7 @@ def halo_reduce(arr, sharding_info=None):
global_shape = sharding_info.global_shape global_shape = sharding_info.global_shape
arr = jaxdecomp.halo_exchange(arr, arr = jaxdecomp.halo_exchange(arr,
halo_extents=(halo_size//2, halo_size//2, 0), halo_extents=(halo_size//2, halo_size//2, 0),
halo_periods=(True,True,True), halo_periods=(True,True,True))
pdims=sharding_info.pdims,
global_shape=(global_shape[0]+2*halo_size,
global_shape[1]+halo_size,
global_shape[2]))
# Apply correction along x # Apply correction along x
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2]) arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
@ -70,7 +60,6 @@ def meshgrid3d(shape, sharding_info=None):
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3]) return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, sharding_info=None): def zeros(shape, sharding_info=None):
""" Initialize an array of given global shape """ Initialize an array of given global shape
partitionned if need be accross dimensions. partitionned if need be accross dimensions.

View file

@ -4,76 +4,114 @@ import jax.lax as lax
from jaxpm.ops import halo_reduce from jaxpm.ops import halo_reduce
from jaxpm.kernels import fftk, cic_compensation from jaxpm.kernels import fftk, cic_compensation
import jaxdecomp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
def cic_paint(mesh, positions, halo_size=0, sharding_info=None): @partial(jax.jit,static_argnums=(1))
def add_halo(positions , halo_size):
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
return positions
def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
print(f" positions {positions.shape}")
if sharding_info is not None: if sharding_info is not None:
# Add some padding for the halo exchange
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
out_specs=P('z', 'y'))
def sharded_pad(arr):
padded = jnp.pad(arr,pad_width=((halo_size, halo_size), (halo_size, halo_size), (0, 0)))
return padded
# Add some padding for the halo exchange
with gpu_mesh:
nbody_mesh = sharded_pad(nbody_mesh)
positions = add_halo(positions , halo_size)
with gpu_mesh:
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]]) [0., 1, 1], [1., 1, 1]]])
with gpu_mesh:
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = jnp.mod(neighboor_coords.reshape( neighboor_coords = jnp.mod(neighboor_coords.reshape(
[-1, 8, 3]).astype('int32'), jnp.array(mesh.shape)) [-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
dnums = jax.lax.ScatterDimensionNumbers( dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(), update_window_dims=(),
inserted_window_dims=(0, 1, 2), inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2)) scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
with gpu_mesh:
nbody_mesh = lax.scatter_add(nbody_mesh,
neighboor_coords, neighboor_coords,
kernel.reshape([-1, 8]), kernel.reshape([-1, 8]),
dnums) dnums)
if sharding_info == None: if sharding_info == None:
return mesh return nbody_mesh
else: else:
mesh = halo_reduce(mesh, sharding_info) with gpu_mesh :
return mesh[halo_size:-halo_size, halo_size:-halo_size] nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
return nbody_mesh
@jax.jit
def reduce_and_sum(mesh,neighboor_coords,kernel):
return (mesh[neighboor_coords[..., 0],
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1)
def cic_read(mesh, positions, halo_size=0, sharding_info=None):
def cic_read(gpu_mesh , mesh, positions, halo_size=0, sharding_info=None):
""" Paints positions onto mesh """ Paints positions onto mesh
mesh: [nx, ny, nz] mesh: [nx, ny, nz]
positions: [npart, 3] positions: [npart, 3]
""" """
@partial(shard_map, mesh=gpu_mesh, in_specs=(P('z', 'y'),P()),
out_specs=P('z', 'y'))
def sharded_pad(arr , padding_width):
return jnp.pad(arr,pad_width=padding_width)
if sharding_info is not None: if sharding_info is not None:
# Add some padding and perfom hao exchange to retrieve # Add some padding and perfom hao exchange to retrieve
# neighboring regions # neighboring regions
mesh = jnp.pad(mesh, [[halo_size, halo_size],
[halo_size, halo_size],
[0, 0]])
# mesh = halo_reduce(mesh, sharding_info) # mesh = halo_reduce(mesh, sharding_info)
import jaxdecomp with gpu_mesh:
padding_width = jnp.array([(halo_size, halo_size), (halo_size, halo_size), (0, 0)])
#mesh = sharded_pad(mesh,padding_width)
mesh = jaxdecomp.halo_exchange(mesh, mesh = jaxdecomp.halo_exchange(mesh,
halo_extents=sharding_info.halo_extents, halo_extents=sharding_info.halo_extents,
halo_periods=(True,True,True), halo_periods=(True,True,True))
pdims=sharding_info.pdims, positions = add_halo(positions , halo_size)
global_shape=sharding_info.global_shape)
positions += jnp.array([halo_size, halo_size, 0]).reshape([-1, 3])
with gpu_mesh:
positions = jnp.expand_dims(positions, 1) positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions) floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1], [0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]]) [0., 1, 1], [1., 1, 1]]])
with gpu_mesh:
neighboor_coords = floor + connection neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords) kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2] kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
@ -81,9 +119,9 @@ def cic_read(mesh, positions, halo_size=0, sharding_info=None):
neighboor_coords = jnp.mod( neighboor_coords = jnp.mod(
neighboor_coords.astype('int32'), jnp.array(mesh.shape)) neighboor_coords.astype('int32'), jnp.array(mesh.shape))
return (mesh[neighboor_coords[..., 0], reduced = reduce_and_sum(mesh,neighboor_coords,kernel)
neighboor_coords[..., 1],
neighboor_coords[..., 3]]*kernel).sum(axis=-1) return reduced
def cic_paint_2d(mesh, positions, weight): def cic_paint_2d(mesh, positions, weight):

View file

@ -8,9 +8,14 @@ from jaxpm.ops import fft3d, ifft3d, zeros, normal
from jaxpm.kernels import fftk, apply_gradient_laplace from jaxpm.kernels import fftk, apply_gradient_laplace
from jaxpm.painting import cic_paint, cic_read from jaxpm.painting import cic_paint, cic_read
from jaxpm.growth import growth_factor, growth_rate, dGfa from jaxpm.growth import growth_factor, growth_rate, dGfa
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from jax.experimental.shard_map import shard_map
from functools import partial
def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_info=None):
""" """
Computes gravitational forces on particles using a PM scheme Computes gravitational forces on particles using a PM scheme
""" """
@ -22,37 +27,87 @@ def pm_forces(positions, mesh_shape=None, delta_k=None, halo_size=0, sharding_in
# Computes gravitational forces # Computes gravitational forces
kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info) kvec = fftk(delta_k.shape, symmetric=False, sharding_info=sharding_info)
forces_k = apply_gradient_laplace(delta_k, kvec)
local_kx = kvec[0]
local_ky = kvec[1]
replicated_kz = kvec[2]
gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z'))
gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y'))
@partial(jax.jit,static_argnums=(1))
def ifft3d_c2r(forces_k , i):
return ifft3d(forces_k[..., i], sharding_info=sharding_info).real
forces = []
with mesh:
forces_k = apply_gradient_laplace(delta_k, gspmd_kx , gspmd_ky , replicated_kz)
# Interpolate forces at the position of particles # Interpolate forces at the position of particles
return jnp.stack([cic_read(ifft3d(forces_k[..., i], sharding_info=sharding_info).real,
positions, halo_size=halo_size, sharding_info=sharding_info) for i in range(3):
for i in range(3)], axis=-1) with mesh:
ifft_forces = ifft3d_c2r(forces_k , i)
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
forces.append(force)
print(f"Shape {ifft_forces.shape}")
return jnp.stack(forces)
def lpt(cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_info=None):
""" """
Computes first order LPT displacement Computes first order LPT displacement
""" """
initial_force = pm_forces( initial_force = pm_forces(mesh,
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info) positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
a = jnp.atleast_1d(a) a = jnp.atleast_1d(a)
dx = growth_factor(cosmo, a) * initial_force
p = a**2 * growth_rate(cosmo, a) * \ print(f"Shape initial {initial_conditions.shape}")
@jax.jit
def compute_dx(cosmo , i_force):
return growth_factor(cosmo, a) * i_force
@jax.jit
def compute_p(cosmo , dx):
return a**2 * growth_rate(cosmo, a) * \
jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx jnp.sqrt(jc.background.Esqr(cosmo, a)) * dx
f = a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
@jax.jit
def compute_f(cosmo , initial_force):
return a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a)) * \
dGfa(cosmo, a) * initial_force dGfa(cosmo, a) * initial_force
with mesh:
dx = compute_dx(cosmo , initial_force)
p = compute_p(cosmo , dx)
f = compute_f(cosmo , initial_force)
return dx, p, f return dx, p, f
def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None): @jax.jit
def interpolate(kfield, kx, ky, kz , k , pk):
return kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2), k, jnp.sqrt(pk))
def linear_field(cosmo, mesh, mesh_shape, box_size, key, sharding_info=None):
""" """
Generate initial conditions in Fourier space. Generate initial conditions in Fourier space.
""" """
# Sample normal field # Sample normal field
field = normal(key, mesh_shape, sharding_info=sharding_info) pdims = sharding_info.pdims
slice_shape = (mesh_shape[0] // pdims[1], mesh_shape[1] // pdims[0],mesh_shape[2])
slice_field = normal(key, slice_shape, sharding_info=sharding_info)
field = multihost_utils.host_local_array_to_global_array(
slice_field, mesh, P('z', 'y'))
# Transform to Fourier space # Transform to Fourier space
with mesh :
kfield = fft3d(field, sharding_info=sharding_info) kfield = fft3d(field, sharding_info=sharding_info)
# Rescaling k to physical units # Rescaling k to physical units
@ -68,11 +123,17 @@ def linear_field(cosmo, mesh_shape, box_size, key, sharding_info=None):
) / (box_size[0] * box_size[1] * box_size[2]) ) / (box_size[0] * box_size[1] * box_size[2])
# Multipliyng the field by the proper power spectrum # Multipliyng the field by the proper power spectrum
kfield = xmap(lambda kfield, kx, ky, kz:
kfield * jc.scipy.interpolate.interp(jnp.sqrt(kx**2+ky**2+kz**2), local_kx = kvec[0]
k, jnp.sqrt(pk)), local_ky = kvec[1]
in_axes=(('x', 'y', ...), ['x'], ['y'], [...]), replicated_kz = kvec[2]
out_axes=('x', 'y', ...))(kfield, kvec[0], kvec[1], kvec[2])
gspmd_kx = multihost_utils.host_local_array_to_global_array(local_kx ,mesh, P('z'))
gspmd_ky = multihost_utils.host_local_array_to_global_array(local_ky ,mesh, P('y'))
with mesh:
kfield = interpolate(kfield,gspmd_kx, gspmd_ky, replicated_kz ,k, pk)
return kfield return kfield

View file

@ -10,19 +10,20 @@ from jaxpm.painting import cic_paint
from jax.experimental.ode import odeint from jax.experimental.ode import odeint
import jax_cosmo as jc import jax_cosmo as jc
import time import time
from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
from functools import partial
### Setting up a whole bunch of things ####### ### Setting up a whole bunch of things #######
# Create communicators # Create communicators
world = MPI.COMM_WORLD world = MPI.COMM_WORLD
rank = world.Get_rank() rank = world.Get_rank()
size = world.Get_size() size = world.Get_size()
jax.config.update("jax_enable_x64", True)
# Here we assume clients are on the same node, so we restrict which device # Here we assume clients are on the same node, so we restrict which device
# they can use based on their rank # they can use based on their rank
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % (rank + 1) jax.distributed.initialize()
jaxdecomp.init()
# Setup random keys # Setup random keys
master_key = jax.random.PRNGKey(42) master_key = jax.random.PRNGKey(42)
@ -35,37 +36,57 @@ mesh_shape = (N, N, N)
box_size = [500, 500, 500] # Mpc/h box_size = [500, 500, 500] # Mpc/h
halo_size = 32 halo_size = 32
sharding_info = ShardingInfo(global_shape=mesh_shape, sharding_info = ShardingInfo(global_shape=mesh_shape,
pdims=(1,2), pdims=(2,2),
halo_extents=(halo_size, halo_size, 0), halo_extents=(halo_size, halo_size, 0),
rank=rank) rank=rank)
cosmo = jc.Planck15() cosmo = jc.Planck15()
a = 0.1 a = 0.1
@jax.jit devices = mesh_utils.create_device_mesh(sharding_info.pdims[::-1])
def run_sim(cosmo, key): mesh = Mesh(devices, axis_names=('z', 'y'))
initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
sharding_info=sharding_info) sharding_info=sharding_info)
init_field = ifft3d(initial_conditions, sharding_info=sharding_info).real @jax.jit
def ifft3d_c2r(initial_conditions):
return ifft3d(initial_conditions, sharding_info=sharding_info).real
@jax.jit
def compute_displacement(p , dx):
return p + dx
def run_sim(mesh , initial_conditions, cosmo, key):
with mesh:
init_field = ifft3d_c2r(initial_conditions)
# Initialize particles # Initialize particles
pos = meshgrid3d(mesh_shape, sharding_info=sharding_info) pos = meshgrid3d(mesh_shape, sharding_info=sharding_info)
# Initial displacement by LPT # Initial displacement by LPT
cosmo = jc.Planck15() cosmo = jc.Planck15()
dx, p, f = lpt(cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
dx, p, f = lpt(mesh , cosmo, pos, initial_conditions, a, halo_size=halo_size, sharding_info=sharding_info)
# And now, we run an actual nbody # And now, we run an actual nbody
res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info), #res = odeint(make_ode_fn(mesh_shape, halo_size, sharding_info),
[pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo, # [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
rtol=1e-3, atol=1e-3) # rtol=1e-3, atol=1e-3)
# Painting on a new mesh ## Painting on a new mesh
field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info), print(f"shape of p {p.shape}")
res[0][-1], halo_size, sharding_info=sharding_info) print(f"shape of dx {dx.shape}")
with mesh:
displacement = compute_displacement(p , dx)
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
field = cic_paint(mesh , empty_field,
displacement, halo_size, sharding_info=sharding_info)
# field = cic_paint(zeros(mesh_shape, sharding_info=sharding_info),
# pos+dx, halo_size, sharding_info=sharding_info)
return init_field, field return init_field, field
# initial_conditions = linear_field(cosmo, mesh_shape, box_size, key, # initial_conditions = linear_field(cosmo, mesh_shape, box_size, key,
@ -87,7 +108,8 @@ def run_sim(cosmo, key):
# pos+dx, halo_size, sharding_info=sharding_info) # pos+dx, halo_size, sharding_info=sharding_info)
# # Recover the real space initial conditions # # Recover the real space initial conditions
init_field, field = run_sim(cosmo, key) run_sim(mesh , initial_conditions,cosmo, key)
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
# import jaxdecomp # import jaxdecomp
# field = jaxdecomp.halo_exchange(field, # field = jaxdecomp.halo_exchange(field,
@ -102,6 +124,10 @@ init_field, field = run_sim(cosmo, key)
# time2 = time.time() # time2 = time.time()
# if rank == 0: # if rank == 0:
onp.save('simulation_%d.npy'%rank, field) #onp.save('simulation_%d.npy'%rank, field)
# print('Done in', time2-time1) # print('Done in', time2-time1)
print("Done")
jaxdecomp.finalize()