mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Adding begnning of implem
This commit is contained in:
parent
3c1abbafcd
commit
1948eae9ed
3 changed files with 68 additions and 87 deletions
16
jaxpm/ops.py
16
jaxpm/ops.py
|
@ -100,12 +100,24 @@ def halo_reduce(arr, halo_size, token=None, comms=None):
|
|||
rank_y = comms[1].Get_rank()
|
||||
margin = arr[:, -2*halo_size:]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
|
||||
comm=comms[0], token=token)
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, :2*halo_size].add(margin)
|
||||
|
||||
margin = arr[:, :2*halo_size]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
|
||||
comm=comms[0], token=token)
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, -2*halo_size:].add(margin)
|
||||
|
||||
return arr, token
|
||||
|
||||
def zeros(shape, comms=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
"""
|
||||
if comms is None:
|
||||
return jnp.zeros(shape)
|
||||
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue