mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-24 11:50:53 +00:00
Implemented a few fixes to the FFT
This commit is contained in:
parent
1948eae9ed
commit
429813ad92
2 changed files with 88 additions and 22 deletions
45
jaxpm/ops.py
45
jaxpm/ops.py
|
@ -4,7 +4,7 @@ import jax.numpy as jnp
|
||||||
import mpi4jax
|
import mpi4jax
|
||||||
|
|
||||||
|
|
||||||
def fft3d(arr, token=None, comms=None):
|
def fft3d(arr, comms=None):
|
||||||
""" Computes forward FFT, note that the output is transposed
|
""" Computes forward FFT, note that the output is transposed
|
||||||
"""
|
"""
|
||||||
if comms is not None:
|
if comms is not None:
|
||||||
|
@ -20,7 +20,7 @@ def fft3d(arr, token=None, comms=None):
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
|
||||||
arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
|
||||||
|
|
||||||
# Second FFT along x
|
# Second FFT along x
|
||||||
|
@ -35,15 +35,10 @@ def fft3d(arr, token=None, comms=None):
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
|
||||||
|
|
||||||
# Third FFT along y
|
# Third FFT along y
|
||||||
arr = jnp.fft.fft(arr)
|
return jnp.fft.fft(arr)
|
||||||
|
|
||||||
if comms == None:
|
|
||||||
return arr
|
|
||||||
else:
|
|
||||||
return arr, token
|
|
||||||
|
|
||||||
|
|
||||||
def ifft3d(arr, token=None, comms=None):
|
def ifft3d(arr, comms=None):
|
||||||
""" Let's assume that the data is distributed accross x
|
""" Let's assume that the data is distributed accross x
|
||||||
"""
|
"""
|
||||||
if comms is not None:
|
if comms is not None:
|
||||||
|
@ -59,7 +54,7 @@ def ifft3d(arr, token=None, comms=None):
|
||||||
else:
|
else:
|
||||||
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
|
||||||
arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
|
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
|
||||||
|
|
||||||
# Second FFT along x
|
# Second FFT along x
|
||||||
|
@ -73,22 +68,17 @@ def ifft3d(arr, token=None, comms=None):
|
||||||
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
|
||||||
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
|
||||||
|
|
||||||
# Third FFT along y
|
# Third FFT along z
|
||||||
arr = jnp.fft.fft(arr)
|
return jnp.fft.ifft(arr)
|
||||||
|
|
||||||
if comms == None:
|
|
||||||
return arr
|
|
||||||
else:
|
|
||||||
return arr, token
|
|
||||||
|
|
||||||
|
|
||||||
def halo_reduce(arr, halo_size, token=None, comms=None):
|
def halo_reduce(arr, halo_size, comms=None):
|
||||||
|
|
||||||
# Perform halo exchange along x
|
# Perform halo exchange along x
|
||||||
rank_x = comms[0].Get_rank()
|
rank_x = comms[0].Get_rank()
|
||||||
margin = arr[-2*halo_size:]
|
margin = arr[-2*halo_size:]
|
||||||
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
|
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
|
||||||
comm=comms[0], token=token)
|
comm=comms[0])
|
||||||
arr = arr.at[:2*halo_size].add(margin)
|
arr = arr.at[:2*halo_size].add(margin)
|
||||||
|
|
||||||
margin = arr[:2*halo_size]
|
margin = arr[:2*halo_size]
|
||||||
|
@ -108,7 +98,8 @@ def halo_reduce(arr, halo_size, token=None, comms=None):
|
||||||
comm=comms[1], token=token)
|
comm=comms[1], token=token)
|
||||||
arr = arr.at[:, -2*halo_size:].add(margin)
|
arr = arr.at[:, -2*halo_size:].add(margin)
|
||||||
|
|
||||||
return arr, token
|
return arr
|
||||||
|
|
||||||
|
|
||||||
def zeros(shape, comms=None):
|
def zeros(shape, comms=None):
|
||||||
""" Initialize an array of given global shape
|
""" Initialize an array of given global shape
|
||||||
|
@ -121,3 +112,17 @@ def zeros(shape, comms=None):
|
||||||
ny = comms[1].Get_size()
|
ny = comms[1].Get_size()
|
||||||
|
|
||||||
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
||||||
|
|
||||||
|
|
||||||
|
def normal(key, shape, comms=None):
|
||||||
|
""" Generates a normal variable for the given
|
||||||
|
global shape.
|
||||||
|
"""
|
||||||
|
if comms is None:
|
||||||
|
return jax.random.normal(key, shape)
|
||||||
|
|
||||||
|
nx = comms[0].Get_size()
|
||||||
|
ny = comms[1].Get_size()
|
||||||
|
|
||||||
|
return jax.random.normal(key,
|
||||||
|
[shape[0]//nx, shape[1]//ny]+list(shape[2:]))
|
||||||
|
|
61
scripts/test_fft3d.py
Normal file
61
scripts/test_fft3d.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
from mpi4py import MPI
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import mpi4jax
|
||||||
|
from jaxpm.ops import fft3d, ifft3d, normal
|
||||||
|
|
||||||
|
# Create communicators
|
||||||
|
world = MPI.COMM_WORLD
|
||||||
|
rank = world.Get_rank()
|
||||||
|
size = world.Get_size()
|
||||||
|
|
||||||
|
cart_comm = MPI.COMM_WORLD.Create_cart(dims=[2, 2],
|
||||||
|
periods=[True, True])
|
||||||
|
comms = [cart_comm.Sub([True, False]),
|
||||||
|
cart_comm.Sub([False, True])]
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
print("Communication setup done!")
|
||||||
|
|
||||||
|
|
||||||
|
# Setup random keys
|
||||||
|
master_key = jax.random.PRNGKey(42)
|
||||||
|
key = jax.random.split(master_key, size)[rank]
|
||||||
|
|
||||||
|
# Size of the FFT
|
||||||
|
N = 256
|
||||||
|
mesh_shape = [N, N, N]
|
||||||
|
|
||||||
|
# Generate a random gaussian variable for the global
|
||||||
|
# mesh shape
|
||||||
|
original_array = normal(key, mesh_shape, comms=comms)
|
||||||
|
|
||||||
|
# Run a forward FFT
|
||||||
|
karray = jax.jit(lambda x: fft3d(x, comms=comms))(original_array)
|
||||||
|
rarray = jax.jit(lambda x: ifft3d(x, comms=comms))(karray)
|
||||||
|
|
||||||
|
# Testing that the fft is indeed invertible
|
||||||
|
print("I'm ", rank, abs(rarray.real - original_array).mean())
|
||||||
|
|
||||||
|
|
||||||
|
# Testing that the FFT is actually what we expect
|
||||||
|
total_array, token = mpi4jax.allgather(original_array, comm=comms[0])
|
||||||
|
total_array = total_array.reshape([N, N//2, N])
|
||||||
|
total_array, token = mpi4jax.allgather(
|
||||||
|
total_array.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||||
|
total_array = total_array.reshape([N, N, N])
|
||||||
|
total_array = total_array.transpose([1, 0, 2])
|
||||||
|
|
||||||
|
total_karray, token = mpi4jax.allgather(karray, comm=comms[0], token=token)
|
||||||
|
total_karray = total_karray.reshape([N, N//2, N])
|
||||||
|
total_karray, token = mpi4jax.allgather(
|
||||||
|
total_karray.transpose([1, 0, 2]), comm=comms[1], token=token)
|
||||||
|
total_karray = total_karray.reshape([N, N, N])
|
||||||
|
total_karray = total_karray.transpose([1, 0, 2])
|
||||||
|
|
||||||
|
print('FFT test:', rank, abs(jnp.fft.fftn(
|
||||||
|
total_array).transpose([2, 0, 1]) - total_karray).mean())
|
||||||
|
|
||||||
|
if rank == 0:
|
||||||
|
print("For reference, the mean value of the fft is", jnp.abs(jnp.fft.fftn(
|
||||||
|
total_array)).mean())
|
Loading…
Add table
Reference in a new issue