mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
Implemented a few fixes to the FFT
This commit is contained in:
parent
1948eae9ed
commit
429813ad92
2 changed files with 88 additions and 22 deletions
61
scripts/test_fft3d.py
Normal file
61
scripts/test_fft3d.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
from mpi4py import MPI
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import mpi4jax
|
||||
from jaxpm.ops import fft3d, ifft3d, normal
|
||||
|
||||
# Create communicators
|
||||
world = MPI.COMM_WORLD
|
||||
rank = world.Get_rank()
|
||||
size = world.Get_size()
|
||||
|
||||
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
|
||||
periods=[True, True])
|
||||
comms = [cart_comm.Sub([True, False]),
|
||||
cart_comm.Sub([False, True])]
|
||||
|
||||
if rank == 0:
|
||||
print("Communication setup done!")
|
||||
|
||||
|
||||
# Setup random keys
|
||||
master_key = jax.random.PRNGKey(42)
|
||||
key = jax.random.split(master_key, size)[rank]
|
||||
|
||||
# Size of the FFT
|
||||
N = 256
|
||||
mesh_shape = [N, N, N]
|
||||
|
||||
# Generate a random gaussian variable for the global
|
||||
# mesh shape
|
||||
original_array = normal(key, mesh_shape, comms=comms)
|
||||
|
||||
# Run a forward FFT
|
||||
karray = jax.jit(lambda x: fft3d(x, comms=comms))(original_array)
|
||||
rarray = jax.jit(lambda x: ifft3d(x, comms=comms))(karray)
|
||||
|
||||
# Testing that the fft is indeed invertible
|
||||
print("I'm ", rank, abs(rarray.real - original_array).mean())
|
||||
|
||||
|
||||
# Testing that the FFT is actually what we expect
|
||||
total_array, token = mpi4jax.allgather(original_array, comm=comms[0])
|
||||
total_array = total_array.reshape([N, N//2, N])
|
||||
total_array, token = mpi4jax.allgather(
|
||||
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||
total_array = total_array.reshape([N, N, N])
|
||||
total_array = total_array.transpose([1, 0, 2])
|
||||
|
||||
total_karray, token = mpi4jax.allgather(karray, comm=comms[0], token=token)
|
||||
total_karray = total_karray.reshape([N, N//2, N])
|
||||
total_karray, token = mpi4jax.allgather(
|
||||
total_karray.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||
total_karray = total_karray.reshape([N, N, N])
|
||||
total_karray = total_karray.transpose([1, 0, 2])
|
||||
|
||||
print('FFT test:', rank, abs(jnp.fft.fftn(
|
||||
total_array).transpose([2, 0, 1]) - total_karray).mean())
|
||||
|
||||
if rank == 0:
|
||||
print("For reference, the mean value of the fft is", jnp.abs(jnp.fft.fftn(
|
||||
total_array)).mean())
|
Loading…
Add table
Add a link
Reference in a new issue