Adding begnning of implem

This commit is contained in:
EiffL 2022-10-22 11:30:25 -05:00
parent 3c1abbafcd
commit 1948eae9ed
3 changed files with 68 additions and 87 deletions

View file

@ -100,12 +100,24 @@ def halo_reduce(arr, halo_size, token=None, comms=None):
rank_y = comms[1].Get_rank()
margin = arr[:, -2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
comm=comms[0], token=token)
comm=comms[1], token=token)
arr = arr.at[:, :2*halo_size].add(margin)
margin = arr[:, :2*halo_size]
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
comm=comms[0], token=token)
comm=comms[1], token=token)
arr = arr.at[:, -2*halo_size:].add(margin)
return arr, token
def zeros(shape, comms=None):
""" Initialize an array of given global shape
partitionned if need be accross dimensions.
"""
if comms is None:
return jnp.zeros(shape)
nx = comms[0].Get_size()
ny = comms[1].Get_size()
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))