# Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git
# (synced 2025-04-24 11:50:53 +00:00; 61 lines, 1.9 KiB, Python)
from mpi4py import MPI
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import mpi4jax
|
|
from jaxpm.ops import fft3d, ifft3d, normal
|
|
|
|
# Create communicators
world = MPI.COMM_WORLD
rank = world.Get_rank()
size = world.Get_size()

# Shape of the 2D Cartesian process grid. The script must be launched with
# exactly PROC_GRID[0] * PROC_GRID[1] MPI processes (e.g. ``mpirun -n 4``).
PROC_GRID = [2, 2]
_n_grid_procs = PROC_GRID[0] * PROC_GRID[1]
if size != _n_grid_procs:
    # Every rank sees the same world size, so all ranks exit together here
    # with a readable message instead of the opaque error raised later by
    # Create_cart / Sub on a mismatched number of processes.
    raise SystemExit(
        f"This test needs exactly {_n_grid_procs} MPI processes "
        f"(process grid {PROC_GRID}), got {size}.")

cart_comm = world.Create_cart(dims=PROC_GRID, periods=[True, True])
# Sub([True, False]) keeps grid dimension 0 and Sub([False, True]) keeps grid
# dimension 1, giving one sub-communicator per grid axis.
comms = [cart_comm.Sub([True, False]),
         cart_comm.Sub([False, True])]

if rank == 0:
    print("Communication setup done!")


# Setup random keys: split the fixed seed 42 into one key per MPI rank
# (``size`` keys in total) and keep the key indexed by this rank.
master_key = jax.random.PRNGKey(42)
per_rank_keys = jax.random.split(master_key, size)
key = per_rank_keys[rank]

# Size of the FFT: a cubic mesh with N cells along each of the 3 axes.
N = 256
mesh_shape = [N] * 3

# Generate a random gaussian variable for the global mesh shape. With
# ``comms`` given, normal() presumably returns only this rank's local block
# of the distributed field -- TODO confirm against jaxpm.ops.
original_array = normal(key, mesh_shape, comms=comms)

# Jit-compiled distributed forward and inverse FFTs over the process grid.
forward_fft = jax.jit(lambda x: fft3d(x, comms=comms))
inverse_fft = jax.jit(lambda x: ifft3d(x, comms=comms))

karray = forward_fft(original_array)
rarray = inverse_fft(karray)

# Testing that the fft is indeed invertible: the mean absolute difference
# between the round-tripped real part and the input should be ~0.
roundtrip_error = abs(rarray.real - original_array).mean()
print("I'm ", rank, roundtrip_error)


# Testing that the FFT is actually what we expect: gather the distributed
# real-space and Fourier-space arrays onto every rank and compare them with a
# single-process jnp.fft.fftn of the full mesh.


def _gather_global(local, token=None):
    """Reassemble a 3D array distributed over the two sub-communicators.

    Gathers over ``comms[0]`` and then over ``comms[1]``. Each gathered axis is
    concatenated onto the local axis it was split from, so every rank ends up
    holding the full array. The concatenated sizes come from the array's own
    shape instead of a hard-coded ``N // 2``.

    Assumes local axis 0 is the one split across ``comms[0]`` and local axis 1
    the one split across ``comms[1]``. On the 2x2 grid this is equivalent to
    the previous hard-coded reshapes; for other grids -- TODO confirm against
    jaxpm.ops.

    Args:
        local: This rank's 3D block of the distributed array.
        token: mpi4jax token returned by a previous collective, or None.

    Returns:
        ``(global_array, token)``. Pass the token to the next mpi4jax call so
        the collectives stay ordered.
    """
    # allgather adds a leading axis over the communicator's ranks:
    # (p0, n0, n1, n2) -> merge the first two axes -> (p0 * n0, n1, n2).
    gathered, token = mpi4jax.allgather(local, comm=comms[0], token=token)
    gathered = gathered.reshape((-1,) + gathered.shape[2:])
    # Bring the axis split over comms[1] to the front, gather it, merge it,
    # then restore the original axis order.
    gathered, token = mpi4jax.allgather(
        gathered.transpose([1, 0, 2]), comm=comms[1], token=token)
    gathered = gathered.reshape((-1,) + gathered.shape[2:])
    return gathered.transpose([1, 0, 2]), token


# Same collective order as before: real-space array first, then its FFT,
# with the token threaded through all four allgathers.
total_array, token = _gather_global(original_array)
total_karray, token = _gather_global(karray, token=token)

# Single-process reference FFT of the full mesh, computed once and reused.
# The transpose([2, 0, 1]) aligns its axes with the distributed fft3d output.
reference_karray = jnp.fft.fftn(total_array)

print('FFT test:', rank,
      abs(reference_karray.transpose([2, 0, 1]) - total_karray).mean())

if rank == 0:
    print("For reference, the mean value of the fft is",
          jnp.abs(reference_karray).mean())