JaxPM/scripts/test_fft3d.py
2022-10-22 13:23:13 -04:00

61 lines
1.9 KiB
Python

from mpi4py import MPI
import jax
import jax.numpy as jnp
import mpi4jax
from jaxpm.ops import fft3d, ifft3d, normal
def _gather_global(local_block, comms, n, token=None):
    """Assemble the full [n, n, n] array from the per-rank blocks.

    The blocks are first gathered along ``comms[0]`` and then along
    ``comms[1]``; the transposes move the axis being gathered to the front
    so that each ``allgather`` result can be collapsed with a plain reshape.

    NOTE(review): the ``n // 2`` in the intermediate reshape is tied to the
    2 x 2 process grid used by this script, and presumably means every rank
    holds an [n // 2, n // 2, n] block -- confirm against ``jaxpm.ops``.

    Args:
        local_block: the array held by the calling rank.
        comms: the two sub-communicators of the Cartesian process grid.
        n: number of cells per side of the global mesh.
        token: mpi4jax token used to order the communications; ``None``
            lets mpi4jax create a fresh one.

    Returns:
        A ``(global_array, token)`` tuple, where ``token`` is the one
        returned by the last ``allgather``.
    """
    gathered, token = mpi4jax.allgather(
        local_block, comm=comms[0], token=token)
    gathered = gathered.reshape([n, n // 2, n])
    gathered, token = mpi4jax.allgather(
        gathered.transpose([1, 0, 2]), comm=comms[1], token=token)
    gathered = gathered.reshape([n, n, n]).transpose([1, 0, 2])
    return gathered, token


# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()

# The process grid is fixed to 2 x 2, so the script only makes sense on
# exactly 4 ranks. Fail early with a readable message instead of letting
# Create_cart / Sub abort with an opaque MPI error.
pdims = [2, 2]
if size != pdims[0] * pdims[1]:
    raise RuntimeError(
        f"This script must be launched on exactly {pdims[0] * pdims[1]} "
        f"MPI processes (got {size}), e.g. "
        "`mpirun -n 4 python test_fft3d.py`.")

cart_comm = world.Create_cart(dims=pdims, periods=[True, True])
# One sub-communicator per direction of the process grid.
comms = [cart_comm.Sub([True, False]),
         cart_comm.Sub([False, True])]

if rank == 0:
    print("Communication setup done!")

# Setup random keys: every rank gets its own key derived from a shared seed
master_key = jax.random.PRNGKey(42)
key = jax.random.split(master_key, size)[rank]

# Size of the FFT
N = 256
mesh_shape = [N, N, N]

# Generate a random gaussian variable for the global
# mesh shape
original_array = normal(key, mesh_shape, comms=comms)

# Run a forward FFT followed by the inverse FFT
karray = jax.jit(lambda x: fft3d(x, comms=comms))(original_array)
rarray = jax.jit(lambda x: ifft3d(x, comms=comms))(karray)

# Testing that the fft is indeed invertible: the printed mean absolute
# round-trip error should be close to zero on every rank
print("I'm ", rank, abs(rarray.real - original_array).mean())

# Testing that the FFT is actually what we expect: rebuild the global
# real-space and Fourier-space arrays on every rank, chaining the token so
# the collective calls keep their order
total_array, token = _gather_global(original_array, comms, N)
total_karray, token = _gather_global(karray, comms, N, token=token)

# Single-process reference transform, computed once and reused below.
reference_karray = jnp.fft.fftn(total_array)

# The reference axes are permuted with [2, 0, 1] before the comparison;
# presumably fft3d returns its result in that transposed layout.
print('FFT test:', rank, abs(
    reference_karray.transpose([2, 0, 1]) - total_karray).mean())

if rank == 0:
    print("For reference, the mean value of the fft is",
          jnp.abs(reference_karray).mean())