diff --git a/jaxpm/ops.py b/jaxpm/ops.py index 66c9cec..789ff69 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -19,7 +19,8 @@ def fft3d(arr, comms=None): arr = arr.transpose([1, 2, 0]) else: arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) - arr = arr.transpose([2, 1, 3, 0]) # [y, z, x] + #arr = arr.transpose([2, 1, 3, 0]) # [y, z, x] + arr = jnp.einsum('ij,xyjz->iyzx', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work arr, token = mpi4jax.alltoall(arr, comm=comms[0]) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x] @@ -30,7 +31,8 @@ def fft3d(arr, comms=None): arr = arr.transpose([1, 2, 0]) else: arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) - arr = arr.transpose([2, 1, 3, 0]) # [z, x, y] + #arr = arr.transpose([2, 1, 3, 0]) # [z, x, y] + arr = jnp.einsum('ij,yzjx->izxy', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y] @@ -53,7 +55,8 @@ def ifft3d(arr, comms=None): arr = arr.transpose([0, 2, 1]) else: arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) - arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x] + # arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x] + arr = jnp.einsum('ij,zxjy->izyx', jnp.eye(ny), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work arr, token = mpi4jax.alltoall(arr, comm=comms[1]) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x] @@ -64,7 +67,8 @@ def ifft3d(arr, comms=None): arr = arr.transpose([2, 1, 0]) else: arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) - arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z] + # arr = arr.transpose([2, 3, 1, 0]) # now [x, y, z] + arr = jnp.einsum('ij,zyjx->ixyz', jnp.eye(nx), arr) # TODO: remove this hack when we understand why transpose before alltoall doenst work arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]