added fixes to alltoall

This commit is contained in:
EiffL 2022-11-05 23:25:01 +01:00
parent 72ae0fd88f
commit 6644b35d71

View file

@ -19,7 +19,8 @@ def fft3d(arr, comms=None):
arr = arr.transpose([1, 2, 0])
else:
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
#arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
@ -30,7 +31,8 @@ def fft3d(arr, comms=None):
arr = arr.transpose([1, 2, 0])
else:
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
#arr = arr.transpose([2, 1, 3, 0]) # [z, x, y]
arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
@ -53,7 +55,8 @@ def ifft3d(arr, comms=None):
arr = arr.transpose([0, 2, 1])
else:
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
# arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
@ -64,7 +67,8 @@ def ifft3d(arr, comms=None):
arr = arr.transpose([2, 1, 0])
else:
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
# arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z]
arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]