Implemented a few fixes to the FFT

This commit is contained in:
EiffL 2022-10-22 13:23:13 -04:00
parent 1948eae9ed
commit 429813ad92
2 changed files with 88 additions and 22 deletions

View file

@ -4,7 +4,7 @@ import jax.numpy as jnp
import mpi4jax
def fft3d(arr, token=None, comms=None):
def fft3d(arr, comms=None):
""" Computes forward FFT, note that the output is transposed
"""
if comms is not None:
@ -20,7 +20,7 @@ def fft3d(arr, token=None, comms=None):
else:
arr = arr.reshape(shape[:-1]+[nx, shape[-1] // nx])
arr = arr.transpose([2, 1, 3, 0]) # [y, z, x]
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
arr, token = mpi4jax.alltoall(arr, comm=comms[0])
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [y, z, x]
# Second FFT along x
@ -35,15 +35,10 @@ def fft3d(arr, token=None, comms=None):
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z, x, y]
# Third FFT along y
arr = jnp.fft.fft(arr)
if comms == None:
return arr
else:
return arr, token
return jnp.fft.fft(arr)
def ifft3d(arr, token=None, comms=None):
def ifft3d(arr, comms=None):
""" Let's assume that the data is distributed accross x
"""
if comms is not None:
@ -59,7 +54,7 @@ def ifft3d(arr, token=None, comms=None):
else:
arr = arr.reshape(shape[:-1]+[ny, shape[-1] // ny])
arr = arr.transpose([2, 0, 3, 1]) # Now [z, y, x]
arr, token = mpi4jax.alltoall(arr, comm=comms[1], token=token)
arr, token = mpi4jax.alltoall(arr, comm=comms[1])
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [z,y,x]
# Second FFT along x
@ -73,22 +68,17 @@ def ifft3d(arr, token=None, comms=None):
arr, token = mpi4jax.alltoall(arr, comm=comms[0], token=token)
arr = arr.transpose([1, 2, 0, 3]).reshape(shape) # Now [x,y,z]
# Third FFT along y
arr = jnp.fft.fft(arr)
if comms == None:
return arr
else:
return arr, token
# Third FFT along z
return jnp.fft.ifft(arr)
def halo_reduce(arr, halo_size, token=None, comms=None):
def halo_reduce(arr, halo_size, comms=None):
# Perform halo exchange along x
rank_x = comms[0].Get_rank()
margin = arr[-2*halo_size:]
margin, token = mpi4jax.sendrecv(margin, margin, rank_x-1, rank_x+1,
comm=comms[0], token=token)
comm=comms[0])
arr = arr.at[:2*halo_size].add(margin)
margin = arr[:2*halo_size]
@ -108,7 +98,8 @@ def halo_reduce(arr, halo_size, token=None, comms=None):
comm=comms[1], token=token)
arr = arr.at[:, -2*halo_size:].add(margin)
return arr, token
return arr
def zeros(shape, comms=None):
""" Initialize an array of given global shape
@ -116,8 +107,22 @@ def zeros(shape, comms=None):
"""
if comms is None:
return jnp.zeros(shape)
nx = comms[0].Get_size()
ny = comms[1].Get_size()
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
return jnp.zeros([shape[0]//nx, shape[1]//ny]+list(shape[2:]))
def normal(key, shape, comms=None):
""" Generates a normal variable for the given
global shape.
"""
if comms is None:
return jax.random.normal(key, shape)
nx = comms[0].Get_size()
ny = comms[1].Get_size()
return jax.random.normal(key,
[shape[0]//nx, shape[1]//ny]+list(shape[2:]))