mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
pm ok
This commit is contained in:
parent
055ceedb7e
commit
179030377b
4 changed files with 63 additions and 38 deletions
46
jaxpm/ops.py
46
jaxpm/ops.py
|
@ -5,7 +5,10 @@ import jaxdecomp
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from jax import jit
|
||||||
|
from jax.experimental import mesh_utils, multihost_utils
|
||||||
|
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||||
|
from jax.experimental.shard_map import shard_map
|
||||||
@dataclass
|
@dataclass
|
||||||
class ShardingInfo:
|
class ShardingInfo:
|
||||||
"""Class for keeping track of the distribution strategy"""
|
"""Class for keeping track of the distribution strategy"""
|
||||||
|
@ -31,25 +34,44 @@ def ifft3d(arr, sharding_info=None):
|
||||||
arr = jaxdecomp.pifft3d(arr)
|
arr = jaxdecomp.pifft3d(arr)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
def halo_reduce(arr, sharding_info=None):
|
|
||||||
if sharding_info is None:
|
|
||||||
return arr
|
def halo_reduce(arr, halo_size , gpu_mesh):
|
||||||
halo_size = sharding_info.halo_extents[0]
|
|
||||||
global_shape = sharding_info.global_shape
|
with gpu_mesh:
|
||||||
arr = jaxdecomp.halo_exchange(arr,
|
arr = jaxdecomp.halo_exchange(arr,
|
||||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||||
halo_periods=(True,True,True))
|
halo_periods=(True,True,True))
|
||||||
|
|
||||||
# Apply correction along x
|
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||||
|
def apply_correction_x(arr):
|
||||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||||
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||||
|
|
||||||
# Apply correction along y
|
return arr
|
||||||
|
|
||||||
|
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||||
|
def apply_correction_y(arr):
|
||||||
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||||
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||||
|
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
|
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||||
|
def un_pad(arr):
|
||||||
|
return arr[halo_size:-halo_size, halo_size:-halo_size]
|
||||||
|
|
||||||
|
|
||||||
|
# Apply correction along x
|
||||||
|
arr = apply_correction_x(arr)
|
||||||
|
# Apply correction along y
|
||||||
|
arr = apply_correction_y(arr)
|
||||||
|
|
||||||
|
arr = un_pad(arr)
|
||||||
|
|
||||||
|
|
||||||
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def meshgrid3d(shape, sharding_info=None):
|
def meshgrid3d(shape, sharding_info=None):
|
||||||
if sharding_info is not None:
|
if sharding_info is not None:
|
||||||
|
@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None):
|
||||||
|
|
||||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||||
|
|
||||||
def zeros(shape, sharding_info=None):
|
def zeros(mesh , shape, sharding_info=None):
|
||||||
""" Initialize an array of given global shape
|
""" Initialize an array of given global shape
|
||||||
partitionned if need be accross dimensions.
|
partitionned if need be accross dimensions.
|
||||||
"""
|
"""
|
||||||
if sharding_info is None:
|
if sharding_info is None:
|
||||||
return jnp.zeros(shape)
|
return jnp.zeros(shape)
|
||||||
|
|
||||||
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \
|
||||||
|
sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||||
|
|
||||||
|
gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y'))
|
||||||
|
return gspmd_zeros
|
||||||
|
|
||||||
|
|
||||||
def normal(key, shape, sharding_info=None):
|
def normal(key, shape, sharding_info=None):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import jax
|
import jax
|
||||||
|
from jax import jit
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax.lax as lax
|
import jax.lax as lax
|
||||||
|
|
||||||
|
@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
||||||
mesh: [nx, ny, nz]
|
mesh: [nx, ny, nz]
|
||||||
positions: [npart, 3]
|
positions: [npart, 3]
|
||||||
"""
|
"""
|
||||||
print(f" positions {positions.shape}")
|
|
||||||
if sharding_info is not None:
|
if sharding_info is not None:
|
||||||
|
|
||||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
|
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
|
||||||
|
@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
||||||
# Add some padding for the halo exchange
|
# Add some padding for the halo exchange
|
||||||
with gpu_mesh:
|
with gpu_mesh:
|
||||||
nbody_mesh = sharded_pad(nbody_mesh)
|
nbody_mesh = sharded_pad(nbody_mesh)
|
||||||
|
|
||||||
positions = add_halo(positions , halo_size)
|
positions = add_halo(positions , halo_size)
|
||||||
|
|
||||||
with gpu_mesh:
|
with gpu_mesh:
|
||||||
positions = jnp.expand_dims(positions, 1)
|
positions = jnp.expand_dims(positions, 1)
|
||||||
floor = jnp.floor(positions)
|
floor = jit(jnp.floor)(positions)
|
||||||
|
|
||||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||||
[0., 1, 1], [1., 1, 1]]])
|
[0., 1, 1], [1., 1, 1]]])
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def compute_kernels(positions , neighboor_coords):
|
||||||
|
kernel = (1. - jnp.abs(positions - neighboor_coords))
|
||||||
|
return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2])
|
||||||
|
|
||||||
with gpu_mesh:
|
with gpu_mesh:
|
||||||
neighboor_coords = floor + connection
|
neighboor_coords = jit(jnp.add)(floor , connection)
|
||||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
kernel = compute_kernels(positions , neighboor_coords)
|
||||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
|
||||||
|
|
||||||
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
||||||
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
|
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
|
||||||
|
@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
||||||
if sharding_info == None:
|
if sharding_info == None:
|
||||||
return nbody_mesh
|
return nbody_mesh
|
||||||
else:
|
else:
|
||||||
with gpu_mesh :
|
nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh)
|
||||||
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
|
|
||||||
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
|
|
||||||
|
|
||||||
return nbody_mesh
|
return nbody_mesh
|
||||||
|
|
||||||
|
|
|
@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar
|
||||||
|
|
||||||
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
|
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
forces.append(force)
|
forces.append(force)
|
||||||
print(f"Shape {ifft_forces.shape}")
|
|
||||||
|
|
||||||
return jnp.stack(forces)
|
return jnp.stack(forces , axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
|
||||||
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
||||||
a = jnp.atleast_1d(a)
|
a = jnp.atleast_1d(a)
|
||||||
|
|
||||||
print(f"Shape initial {initial_conditions.shape}")
|
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def compute_dx(cosmo , i_force):
|
def compute_dx(cosmo , i_force):
|
||||||
|
@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
|
||||||
p = compute_p(cosmo , dx)
|
p = compute_p(cosmo , dx)
|
||||||
f = compute_f(cosmo , initial_force)
|
f = compute_f(cosmo , initial_force)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return dx, p, f
|
return dx, p, f
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
import os
|
import os
|
||||||
import jax
|
import jax
|
||||||
|
from jax import jit
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as onp
|
import numpy as onp
|
||||||
import jaxdecomp
|
import jaxdecomp
|
||||||
|
@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
|
||||||
def ifft3d_c2r(initial_conditions):
|
def ifft3d_c2r(initial_conditions):
|
||||||
return ifft3d(initial_conditions, sharding_info=sharding_info).real
|
return ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||||
|
|
||||||
@jax.jit
|
|
||||||
def compute_displacement(p , dx):
|
|
||||||
return p + dx
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def run_sim(mesh , initial_conditions, cosmo, key):
|
def run_sim(mesh , initial_conditions, cosmo, key):
|
||||||
|
|
||||||
with mesh:
|
with mesh:
|
||||||
|
@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key):
|
||||||
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||||
# rtol=1e-3, atol=1e-3)
|
# rtol=1e-3, atol=1e-3)
|
||||||
## Painting on a new mesh
|
## Painting on a new mesh
|
||||||
print(f"shape of p {p.shape}")
|
|
||||||
print(f"shape of dx {dx.shape}")
|
|
||||||
with mesh:
|
with mesh:
|
||||||
displacement = compute_displacement(p , dx)
|
displacement = jit(jnp.add)(p , dx)
|
||||||
|
|
||||||
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
|
empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info)
|
||||||
|
|
||||||
field = cic_paint(mesh , empty_field,
|
field = cic_paint(mesh , empty_field,
|
||||||
displacement, halo_size, sharding_info=sharding_info)
|
displacement, halo_size, sharding_info=sharding_info)
|
||||||
|
@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key):
|
||||||
# pos+dx, halo_size, sharding_info=sharding_info)
|
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||||
|
|
||||||
# # Recover the real space initial conditions
|
# # Recover the real space initial conditions
|
||||||
run_sim(mesh , initial_conditions,cosmo, key)
|
init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||||
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||||
|
|
||||||
# import jaxdecomp
|
# import jaxdecomp
|
||||||
|
@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key)
|
||||||
# init_field.block_until_ready()
|
# init_field.block_until_ready()
|
||||||
# time2 = time.time()
|
# time2 = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# if rank == 0:
|
# if rank == 0:
|
||||||
#onp.save('simulation_%d.npy'%rank, field)
|
onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16))
|
||||||
|
onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16))
|
||||||
|
|
||||||
# print('Done in', time2-time1)
|
# print('Done in', time2-time1)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue