fixed a whole lot of issues

This commit is contained in:
EiffL 2022-10-22 15:58:32 -04:00
parent 429813ad92
commit 72ae0fd88f
5 changed files with 251 additions and 155 deletions

View file

@ -73,34 +73,58 @@ def ifft3d(arr, comms=None):
def halo_reduce(arr, halo_size, comms=None):
if halo_size <= 0:
return arr
# Perform halo exchange along x
rank_x = comms[0].Get_rank()
size_x = comms[0].Get_size()
margin = arr[-2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
comm=comms[0])
arr = arr.at[:2*halo_size].add(margin)
left, token = mpi4jax.sendrecv(margin, margin,
(rank_x-1) % size_x,
(rank_x+1) % size_x,
comm=comms[0])
margin = arr[:2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1,
comm=comms[0], token=token)
arr = arr.at[-2*halo_size:].add(margin)
right, token = mpi4jax.sendrecv(margin, margin,
(rank_x+1) % size_x,
(rank_x-1) % size_x,
comm=comms[0], token=token)
arr = arr.at[:2*halo_size].add(left)
arr = arr.at[-2*halo_size:].add(right)
# Perform halo exchange along y
rank_y = comms[1].Get_rank()
size_y = comms[1].Get_size()
margin = arr[:, -2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
comm=comms[1], token=token)
arr = arr.at[:, :2*halo_size].add(margin)
left, token = mpi4jax.sendrecv(margin, margin,
(rank_y-1) % size_y,
(rank_y+1) % size_y,
comm=comms[1], token=token)
margin = arr[:, :2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
comm=comms[1], token=token)
arr = arr.at[:, -2*halo_size:].add(margin)
right, token = mpi4jax.sendrecv(margin, margin,
(rank_y+1) % size_y,
(rank_y-1) % size_y,
comm=comms[1], token=token)
arr = arr.at[:, :2*halo_size].add(left)
arr = arr.at[:, -2*halo_size:].add(right)
return arr
def meshgrid3d(shape, comms=None):
if comms is not None:
nx = comms[0].Get_size()
ny = comms[1].Get_size()
coords = [jnp.arange(shape[0]//nx),
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
else:
coords = [jnp.arange(s) for s in shape[2:]]
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
def zeros(shape, comms=None):
""" Initialize an array of given global shape
partitionned if need be accross dimensions.