mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
pm ok
This commit is contained in:
parent
055ceedb7e
commit
179030377b
4 changed files with 63 additions and 38 deletions
56
jaxpm/ops.py
56
jaxpm/ops.py
|
@ -5,7 +5,10 @@ import jaxdecomp
|
|||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
from functools import partial
|
||||
|
||||
from jax import jit
|
||||
from jax.experimental import mesh_utils, multihost_utils
|
||||
from jax.sharding import Mesh, PartitionSpec as P,NamedSharding
|
||||
from jax.experimental.shard_map import shard_map
|
||||
@dataclass
|
||||
class ShardingInfo:
|
||||
"""Class for keeping track of the distribution strategy"""
|
||||
|
@ -31,22 +34,41 @@ def ifft3d(arr, sharding_info=None):
|
|||
arr = jaxdecomp.pifft3d(arr)
|
||||
return arr
|
||||
|
||||
def halo_reduce(arr, sharding_info=None):
|
||||
if sharding_info is None:
|
||||
|
||||
|
||||
def halo_reduce(arr, halo_size , gpu_mesh):
|
||||
|
||||
with gpu_mesh:
|
||||
arr = jaxdecomp.halo_exchange(arr,
|
||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||
halo_periods=(True,True,True))
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||
def apply_correction_x(arr):
|
||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||
|
||||
return arr
|
||||
halo_size = sharding_info.halo_extents[0]
|
||||
global_shape = sharding_info.global_shape
|
||||
arr = jaxdecomp.halo_exchange(arr,
|
||||
halo_extents=(halo_size//2, halo_size//2, 0),
|
||||
halo_periods=(True,True,True))
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||
def apply_correction_y(arr):
|
||||
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||
|
||||
return arr
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),out_specs=P('z', 'y'))
|
||||
def un_pad(arr):
|
||||
return arr[halo_size:-halo_size, halo_size:-halo_size]
|
||||
|
||||
|
||||
# Apply correction along x
|
||||
arr = arr.at[halo_size:halo_size + halo_size//2].add(arr[ :halo_size//2])
|
||||
arr = arr.at[-halo_size - halo_size//2:-halo_size].add(arr[-halo_size//2:])
|
||||
|
||||
arr = apply_correction_x(arr)
|
||||
# Apply correction along y
|
||||
arr = arr.at[:, halo_size:halo_size + halo_size//2].add(arr[:, :halo_size//2][:, :])
|
||||
arr = arr.at[:, -halo_size - halo_size//2:-halo_size].add(arr[:, -halo_size//2:][:, :])
|
||||
arr = apply_correction_y(arr)
|
||||
|
||||
arr = un_pad(arr)
|
||||
|
||||
|
||||
return arr
|
||||
|
||||
|
@ -60,14 +82,18 @@ def meshgrid3d(shape, sharding_info=None):
|
|||
|
||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||
|
||||
def zeros(shape, sharding_info=None):
|
||||
def zeros(mesh , shape, sharding_info=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
"""
|
||||
if sharding_info is None:
|
||||
return jnp.zeros(shape)
|
||||
|
||||
return jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||
zeros_slice = jnp.zeros([sharding_info.global_shape[0]//sharding_info.pdims[1], \
|
||||
sharding_info.global_shape[1]//sharding_info.pdims[0]]+list(sharding_info.global_shape[2:]))
|
||||
|
||||
gspmd_zeros = multihost_utils.host_local_array_to_global_array(zeros_slice ,mesh, P('z' , 'y'))
|
||||
return gspmd_zeros
|
||||
|
||||
|
||||
def normal(key, shape, sharding_info=None):
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import jax
|
||||
from jax import jit
|
||||
import jax.numpy as jnp
|
||||
import jax.lax as lax
|
||||
|
||||
|
@ -22,7 +23,6 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
|||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
print(f" positions {positions.shape}")
|
||||
if sharding_info is not None:
|
||||
|
||||
@partial(shard_map, mesh=gpu_mesh, in_specs=P('z', 'y'),
|
||||
|
@ -34,20 +34,25 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
|||
# Add some padding for the halo exchange
|
||||
with gpu_mesh:
|
||||
nbody_mesh = sharded_pad(nbody_mesh)
|
||||
|
||||
positions = add_halo(positions , halo_size)
|
||||
|
||||
with gpu_mesh:
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
floor = jit(jnp.floor)(positions)
|
||||
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||
[0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
@jit
|
||||
def compute_kernels(positions , neighboor_coords):
|
||||
kernel = (1. - jnp.abs(positions - neighboor_coords))
|
||||
return (kernel[..., 0] * kernel[..., 1] * kernel[..., 2])
|
||||
|
||||
with gpu_mesh:
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
neighboor_coords = jit(jnp.add)(floor , connection)
|
||||
kernel = compute_kernels(positions , neighboor_coords)
|
||||
|
||||
neighboor_coords = jnp.mod(neighboor_coords.reshape(
|
||||
[-1, 8, 3]).astype('int32'), jnp.array(nbody_mesh.shape))
|
||||
|
@ -66,9 +71,7 @@ def cic_paint(gpu_mesh,nbody_mesh, positions, halo_size=0, sharding_info=None):
|
|||
if sharding_info == None:
|
||||
return nbody_mesh
|
||||
else:
|
||||
with gpu_mesh :
|
||||
nbody_mesh = halo_reduce(nbody_mesh, sharding_info)
|
||||
nbody_mesh = nbody_mesh[halo_size:-halo_size, halo_size:-halo_size]
|
||||
nbody_mesh = halo_reduce(nbody_mesh, sharding_info.halo_extents[0] , gpu_mesh)
|
||||
|
||||
return nbody_mesh
|
||||
|
||||
|
|
|
@ -50,9 +50,8 @@ def pm_forces(mesh , positions, mesh_shape=None, delta_k=None, halo_size=0, shar
|
|||
|
||||
force = cic_read(mesh , ifft_forces, positions, halo_size=halo_size, sharding_info=sharding_info)
|
||||
forces.append(force)
|
||||
print(f"Shape {ifft_forces.shape}")
|
||||
|
||||
return jnp.stack(forces)
|
||||
return jnp.stack(forces , axis=-1)
|
||||
|
||||
|
||||
|
||||
|
@ -64,7 +63,6 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
|
|||
positions, delta_k=initial_conditions, halo_size=halo_size, sharding_info=sharding_info)
|
||||
a = jnp.atleast_1d(a)
|
||||
|
||||
print(f"Shape initial {initial_conditions.shape}")
|
||||
|
||||
@jax.jit
|
||||
def compute_dx(cosmo , i_force):
|
||||
|
@ -85,6 +83,8 @@ def lpt(mesh ,cosmo, positions, initial_conditions, a, halo_size=0, sharding_inf
|
|||
p = compute_p(cosmo , dx)
|
||||
f = compute_f(cosmo , initial_force)
|
||||
|
||||
|
||||
|
||||
return dx, p, f
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from mpi4py import MPI
|
||||
import os
|
||||
import jax
|
||||
from jax import jit
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
import jaxdecomp
|
||||
|
@ -53,12 +54,6 @@ initial_conditions = linear_field(cosmo, mesh, mesh_shape, box_size, key,
|
|||
def ifft3d_c2r(initial_conditions):
|
||||
return ifft3d(initial_conditions, sharding_info=sharding_info).real
|
||||
|
||||
@jax.jit
|
||||
def compute_displacement(p , dx):
|
||||
return p + dx
|
||||
|
||||
|
||||
|
||||
def run_sim(mesh , initial_conditions, cosmo, key):
|
||||
|
||||
with mesh:
|
||||
|
@ -77,12 +72,10 @@ def run_sim(mesh , initial_conditions, cosmo, key):
|
|||
# [pos+dx, p], jnp.linspace(0.1, 1.0, 2), cosmo,
|
||||
# rtol=1e-3, atol=1e-3)
|
||||
## Painting on a new mesh
|
||||
print(f"shape of p {p.shape}")
|
||||
print(f"shape of dx {dx.shape}")
|
||||
with mesh:
|
||||
displacement = compute_displacement(p , dx)
|
||||
displacement = jit(jnp.add)(p , dx)
|
||||
|
||||
empty_field = zeros(mesh_shape, sharding_info=sharding_info)
|
||||
empty_field = zeros(mesh , mesh_shape, sharding_info=sharding_info)
|
||||
|
||||
field = cic_paint(mesh , empty_field,
|
||||
displacement, halo_size, sharding_info=sharding_info)
|
||||
|
@ -108,7 +101,7 @@ def run_sim(mesh , initial_conditions, cosmo, key):
|
|||
# pos+dx, halo_size, sharding_info=sharding_info)
|
||||
|
||||
# # Recover the real space initial conditions
|
||||
run_sim(mesh , initial_conditions,cosmo, key)
|
||||
init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||
#init_field, field = run_sim(mesh , initial_conditions,cosmo, key)
|
||||
|
||||
# import jaxdecomp
|
||||
|
@ -123,8 +116,11 @@ run_sim(mesh , initial_conditions,cosmo, key)
|
|||
# init_field.block_until_ready()
|
||||
# time2 = time.time()
|
||||
|
||||
|
||||
|
||||
# if rank == 0:
|
||||
#onp.save('simulation_%d.npy'%rank, field)
|
||||
onp.save('simulation_init_field_float16_%d.npy'%rank, init_field.addressable_data(0).astype(onp.float16))
|
||||
onp.save('simulation_field_float16_%d.npy'%rank, field.addressable_data(0).astype(onp.float16))
|
||||
|
||||
# print('Done in', time2-time1)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue