JaxPM/dev/test_pfft.py

73 lines
2.1 KiB
Python
Raw Normal View History

# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
import jax
import jax.numpy as jnp
import numpy as np
import jax.lax as lax
from jax.experimental.maps import xmap
from jax.experimental.maps import Mesh
from jax.experimental.pjit import PartitionSpec, pjit
from functools import partial
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x','y', ...],
axis_sizes={'x':cube_size, 'y':cube_size},
axis_resources={'x': 'nx', 'y':'ny',
'key_x':'nx', 'key_y':'ny'})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y', ...],
axis_resources={'x': 'nx', 'y': 'ny'})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={0:'x', 1:'y'},
out_axes=['x','y', ...],
axis_resources={'x': 'nx', 'y': 'ny'})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')