JaxPM_highres/dev/test_pfft.py
2024-07-09 14:54:34 -04:00

96 lines
2.3 KiB
Python

# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
jax.distributed.initialize()
cube_size = 2048
@partial(xmap,
in_axes=[...],
out_axes=['x', 'y', ...],
axis_sizes={
'x': cube_size,
'y': cube_size
},
axis_resources={
'x': 'nx',
'y': 'ny',
'key_x': 'nx',
'key_y': 'ny'
})
def pnormal(key):
return jax.random.normal(key, shape=[cube_size])
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pfft3d(mesh):
# [x, y, z]
mesh = jnp.fft.fft(mesh) # Transform on z
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.fft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
mesh = jnp.fft.fft(mesh) # Transform on y
# [z, x, y]
return mesh
@partial(xmap,
in_axes={
0: 'x',
1: 'y'
},
out_axes=['x', 'y', ...],
axis_resources={
'x': 'nx',
'y': 'ny'
})
@jax.jit
def pifft3d(mesh):
# [z, x, y]
mesh = jnp.fft.ifft(mesh) # Transform on y
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
mesh = jnp.fft.ifft(mesh) # Transform on x
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
mesh = jnp.fft.ifft(mesh) # Transform on z
# [x, y, z]
return mesh
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))
# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))
with Mesh(devices, ('nx', 'ny')):
mesh = pnormal(key)
kmesh = pfft3d(mesh)
kmesh.block_until_ready()
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
# mesh = pnormal(key)
# kmesh = pfft3d(mesh)
# kmesh.block_until_ready()
# jax.profiler.stop_trace()
print('Done')