# Forked from guilhem_lavaux/JaxPM.
#
# Can be executed with:
#   srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial
|
|
|
|
import jax
|
|
import jax.lax as lax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
from jax.experimental.maps import Mesh, xmap
|
|
from jax.experimental.pjit import PartitionSpec, pjit
|
|
|
|
# Bring up JAX's multi-process runtime before any other JAX call is made.
jax.distributed.initialize()

# Number of cells along each side of the cubic mesh.
cube_size = 2**11  # 2048
|
|
|
|
|
|
@partial(xmap,
         in_axes=[...],
         out_axes=['x', 'y', ...],
         axis_sizes={
             'x': cube_size,
             'y': cube_size
         },
         axis_resources={
             'x': 'nx',
             'y': 'ny'
         })
def pnormal(key):
    """Create a distributed cube of standard-normal samples.

    The function body produces one vector of length ``cube_size``; ``xmap``
    adds the named axes ``'x'`` and ``'y'`` (each of size ``cube_size``) in
    front of it and places them on the mesh resources ``'nx'`` and ``'ny'``.
    It must therefore be called inside a ``Mesh`` context that defines those
    two resource names.

    Only axes declared through ``in_axes`` or ``axis_sizes`` are given
    resources here. Earlier ``'key_x'``/``'key_y'`` entries referred to axes
    that are not declared anywhere in this decorator and were removed.

    Args:
        key: A JAX PRNG key. It carries no named axes (``in_axes=[...]``).

    Returns:
        An array with leading named axes ``'x'`` and ``'y'`` followed by one
        positional axis of length ``cube_size``.
    """
    # NOTE(review): the key is not split along 'x'/'y', so every (x, y)
    # position appears to draw from the same key -- confirm this is acceptable
    # for the benchmark.
    return jax.random.normal(key, shape=[cube_size])
|
|
|
|
|
|
@partial(xmap,
|
|
in_axes={
|
|
0: 'x',
|
|
1: 'y'
|
|
},
|
|
out_axes=['x', 'y', ...],
|
|
axis_resources={
|
|
'x': 'nx',
|
|
'y': 'ny'
|
|
})
|
|
@jax.jit
|
|
def pfft3d(mesh):
|
|
# [x, y, z]
|
|
mesh = jnp.fft.fft(mesh) # Transform on z
|
|
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
|
mesh = jnp.fft.fft(mesh) # Transform on x
|
|
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
|
mesh = jnp.fft.fft(mesh) # Transform on y
|
|
# [z, x, y]
|
|
return mesh
|
|
|
|
|
|
@partial(xmap,
|
|
in_axes={
|
|
0: 'x',
|
|
1: 'y'
|
|
},
|
|
out_axes=['x', 'y', ...],
|
|
axis_resources={
|
|
'x': 'nx',
|
|
'y': 'ny'
|
|
})
|
|
@jax.jit
|
|
def pifft3d(mesh):
|
|
# [z, x, y]
|
|
mesh = jnp.fft.ifft(mesh) # Transform on y
|
|
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
|
mesh = jnp.fft.ifft(mesh) # Transform on x
|
|
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
|
mesh = jnp.fft.ifft(mesh) # Transform on z
|
|
# [x, y, z]
|
|
return mesh
|
|
|
|
|
|
key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))

# Arrange every visible device into a 2 x N grid. The second dimension is
# inferred from the device count, so 8 devices still give the original (2, 4)
# layout, while e.g. the 4-GPU launch line at the top of this file gives
# (2, 2). An odd device count still raises a ValueError from reshape.
devices = np.array(jax.devices()).reshape((2, -1))

# The mesh resource names must match those used in the xmap decorators.
with Mesh(devices, ('nx', 'ny')):
    mesh = pnormal(key)
    kmesh = pfft3d(mesh)
    # Wait for the computation to complete before leaving the mesh context.
    kmesh.block_until_ready()

# Uncomment to record a profiler trace of a second run:
# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
#     mesh = pnormal(key)
#     kmesh = pfft3d(mesh)
#     kmesh.block_until_ready()
# jax.profiler.stop_trace()

print('Done')