mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
fixed a whole lot of issues
This commit is contained in:
parent
429813ad92
commit
72ae0fd88f
5 changed files with 251 additions and 155 deletions
52
jaxpm/ops.py
52
jaxpm/ops.py
|
@ -73,34 +73,58 @@ def ifft3d(arr, comms=None):
|
|||
|
||||
|
||||
def halo_reduce(arr, halo_size, comms=None):
|
||||
if halo_size <= 0:
|
||||
return arr
|
||||
|
||||
# Perform halo exchange along x
|
||||
rank_x = comms[0].Get_rank()
|
||||
size_x = comms[0].Get_size()
|
||||
margin = arr[-2*halo_size:]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
|
||||
comm=comms[0])
|
||||
arr = arr.at[:2*halo_size].add(margin)
|
||||
|
||||
left, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_x-1) % size_x,
|
||||
(rank_x+1) % size_x,
|
||||
comm=comms[0])
|
||||
margin = arr[:2*halo_size]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_x+1, rank_x-1,
|
||||
comm=comms[0], token=token)
|
||||
arr = arr.at[-2*halo_size:].add(margin)
|
||||
right, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_x+1) % size_x,
|
||||
(rank_x-1) % size_x,
|
||||
comm=comms[0], token=token)
|
||||
|
||||
arr = arr.at[:2*halo_size].add(left)
|
||||
arr = arr.at[-2*halo_size:].add(right)
|
||||
|
||||
# Perform halo exchange along y
|
||||
rank_y = comms[1].Get_rank()
|
||||
size_y = comms[1].Get_size()
|
||||
margin = arr[:, -2*halo_size:]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_y-1, rank_y+1,
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, :2*halo_size].add(margin)
|
||||
|
||||
left, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_y-1) % size_y,
|
||||
(rank_y+1) % size_y,
|
||||
comm=comms[1], token=token)
|
||||
margin = arr[:, :2*halo_size]
|
||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_y+1, rank_y-1,
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, -2*halo_size:].add(margin)
|
||||
right, token = mpi4jax.sendrecv(margin, margin,
|
||||
(rank_y+1) % size_y,
|
||||
(rank_y-1) % size_y,
|
||||
comm=comms[1], token=token)
|
||||
arr = arr.at[:, :2*halo_size].add(left)
|
||||
arr = arr.at[:, -2*halo_size:].add(right)
|
||||
|
||||
return arr
|
||||
|
||||
|
||||
def meshgrid3d(shape, comms=None):
|
||||
if comms is not None:
|
||||
nx = comms[0].Get_size()
|
||||
ny = comms[1].Get_size()
|
||||
|
||||
coords = [jnp.arange(shape[0]//nx),
|
||||
jnp.arange(shape[1]//ny)] + [jnp.arange(s) for s in shape[2:]]
|
||||
else:
|
||||
coords = [jnp.arange(s) for s in shape[2:]]
|
||||
|
||||
return jnp.stack(jnp.meshgrid(*coords), axis=-1).reshape([-1, 3])
|
||||
|
||||
|
||||
def zeros(shape, comms=None):
|
||||
""" Initialize an array of given global shape
|
||||
partitionned if need be accross dimensions.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue