mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
added fixes to alltoall
This commit is contained in:
parent
72ae0fd88f
commit
6644b35d71
1 changed files with 8 additions and 4 deletions
12
jaxpm/ops.py
12
jaxpm/ops.py
|
@ -19,7 +19,8 @@ def fft3d(arr, comms=None):
|
||||||
arr = arr.transpose([1, 2, 0])
|
arr = arr.transpose([1, 2, 0])
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||||
arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
#arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
||||||
|
arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
|
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
||||||
|
|
||||||
|
@ -30,7 +31,8 @@ def fft3d(arr, comms=None):
|
||||||
arr = arr.transpose([1, 2, 0])
|
arr = arr.transpose([1, 2, 0])
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||||
arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
|
#arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
|
||||||
|
arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
||||||
|
|
||||||
|
@ -53,7 +55,8 @@ def ifft3d(arr, comms=None):
|
||||||
arr = arr.transpose([0, 2, 1])
|
arr = arr.transpose([0, 2, 1])
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||||
arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
# arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
||||||
|
arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
|
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
||||||
|
|
||||||
|
@ -64,7 +67,8 @@ def ifft3d(arr, comms=None):
|
||||||
arr = arr.transpose([2, 1, 0])
|
arr = arr.transpose([2, 1, 0])
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||||
arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
|
# arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
|
||||||
|
arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue