From 429813ad9250bd2dbade3cc974b9d79233aa84c9 Mon Sep 17 00:00:00 2001 From: EiffL Date: Sat, 22 Oct 2022 13:23:13 -0400 Subject: [PATCH] Implemented a few fixes to the FFT --- jaxpm/ops.py | 49 ++++++++++++++++++---------------- scripts/test_fft3d.py | 61 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 22 deletions(-) create mode 100644 scripts/test_fft3d.py diff --git a/jaxpm/ops.py b/jaxpm/ops.py index 68bf355..fe6e387 100644 --- a/jaxpm/ops.py +++ b/jaxpm/ops.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import mpi4jax -def fft3d(arr, token=None, comms=None): +def fft3d(arr, comms=None): """ Computes forward FFT, note that the output is transposed """ if comms is not None: @@ -20,7 +20,7 @@ def fft3d(arr, token=None, comms=None): else: arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx]) arr = arr.transpose([2, 1, 3, 0]) # [y, z, x] - arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token) + arr, token = mpi4jax.alltoall(arr, comm=comms[0]) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x] # Second FFT along x @@ -35,15 +35,10 @@ def fft3d(arr, token=None, comms=None): arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y] # Third FFT along y - arr = jnp.fft.fft(arr) - - if comms == None: - return arr - else: - return arr, token + return jnp.fft.fft(arr) -def ifft3d(arr, token=None, comms=None): +def ifft3d(arr, comms=None): """ Let's assume that the data is distributed accross x """ if comms is not None: @@ -59,7 +54,7 @@ def ifft3d(arr, token=None, comms=None): else: arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny]) arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x] - arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token) + arr, token = mpi4jax.alltoall(arr, comm=comms[1]) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x] # Second FFT along x @@ -73,22 +68,17 @@ def ifft3d(arr, token=None, comms=None): arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token) arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z] - # Third FFT along y - arr = jnp.fft.fft(arr) - - if comms == None: - return arr - else: - return arr, token + # Third FFT along z + return jnp.fft.ifft(arr) -def halo_reduce(arr, halo_size, token=None, comms=None): +def halo_reduce(arr, halo_size, comms=None): # Perform halo exchange along x rank_x = comms[0].Get_rank() margin = arr[-2*halo_size:] margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1, - comm=comms[0], token=token) + comm=comms[0]) arr = arr.at[:2*halo_size].add(margin) margin = arr[:2*halo_size] @@ -108,7 +98,8 @@ def halo_reduce(arr, halo_size, token=None, comms=None): comm=comms[1], token=token) arr = arr.at[:, -2*halo_size:].add(margin) - return arr, token + return arr + def zeros(shape, comms=None): """ Initialize an array of given global shape @@ -116,8 +107,22 @@ def zeros(shape, comms=None): """ if comms is None: return jnp.zeros(shape) - + nx = comms[0].Get_size() ny = comms[1].Get_size() - return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:])) + return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:])) + + +def normal(key, shape, comms=None): + """ Generates a normal variable for the given + global shape. + """ + if comms is None: + return jax.random.normal(key, shape) + + nx = comms[0].Get_size() + ny = comms[1].Get_size() + + return jax.random.normal(key, + [shape[0]//nx, shape[1]//ny]+list(shape[2:])) diff --git a/scripts/test_fft3d.py b/scripts/test_fft3d.py new file mode 100644 index 0000000..07cc4a5 --- /dev/null +++ b/scripts/test_fft3d.py @@ -0,0 +1,61 @@ +from mpi4py import MPI +import jax +import jax.numpy as jnp +import mpi4jax +from jaxpm.ops import fft3d, ifft3d, normal + +# Create communicators +world = MPI.COMM_WORLD +rank = world.Get_rank() +size = world.Get_size() + +cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2], + periods=[True, True]) +comms = [cart_comm.Sub([True, False]), + cart_comm.Sub([False, True])] + +if rank == 0: + print("Communication setup done!") + + +# Setup random keys +master_key = jax.random.PRNGKey(42) +key = jax.random.split(master_key, size)[rank] + +# Size of the FFT +N = 256 +mesh_shape = [N, N, N] + +# Generate a random gaussian variable for the global +# mesh shape +original_array = normal(key, mesh_shape, comms=comms) + +# Run a forward FFT +karray = jax.jit(lambda x: fft3d(x, comms=comms))(original_array) +rarray = jax.jit(lambda x: ifft3d(x, comms=comms))(karray) + +# Testing that the fft is indeed invertible +print("I'm ", rank, abs(rarray.real - original_array).mean()) + + +# Testing that the FFT is actually what we expect +total_array, token = mpi4jax.allgather(original_array, comm=comms[0]) +total_array = total_array.reshape([N, N//2, N]) +total_array, token = mpi4jax.allgather( + total_array.transpose([1, 0, 2]), comm=comms[1], token=token) +total_array = total_array.reshape([N, N, N]) +total_array = total_array.transpose([1, 0, 2]) + +total_karray, token = mpi4jax.allgather(karray, comm=comms[0], token=token) +total_karray = total_karray.reshape([N, N//2, N]) +total_karray, token = mpi4jax.allgather( + total_karray.transpose([1, 0, 2]), comm=comms[1], token=token) +total_karray = total_karray.reshape([N, N, N]) +total_karray = total_karray.transpose([1, 0, 2]) + +print('FFT test:', rank, abs(jnp.fft.fftn( + total_array).transpose([2, 0, 1]) - total_karray).mean()) + +if rank == 0: + print("For reference, the mean value of the fft is", jnp.abs(jnp.fft.fftn( + total_array)).mean())