From 64894726e74a779471a66a811ad2be28fd8f8e30 Mon Sep 17 00:00:00 2001
From: EiffL
Date: Mon, 17 Oct 2022 22:07:12 -0700
Subject: [PATCH] Added example code to compute distributed fft

---
Notes (this area sits between the "---" separator and the diffstat and is
not part of the commit message):

  What the patch adds
    * dev/job_pfft.sh: Slurm batch script. It requests 2 nodes with 4 tasks
      per node and one GPU per task, loads the python/cudnn/nccl/cudatoolkit
      modules, and launches test_pfft.py through srun.
    * dev/test_pfft.py: JAX script built on xmap. It defines
        - pnormal(key): draws a length-cube_size normal vector; the xmap
          wrapper adds the named axes 'x' and 'y' (both of size cube_size).
        - pfft3d(mesh): three 1D FFTs, with lax.all_to_all over the named
          axes 'x' and then 'y' in between, so each dimension is local when
          it is transformed. Layout goes [x, y, z] -> [z, x, y].
        - pifft3d(mesh): the same pattern with ifft and the all_to_all calls
          in the opposite order. Layout goes [z, x, y] -> [x, y, z].
      The driver reshapes jax.devices() into a (2, 4) mesh named
      ('nx', 'ny'), runs pnormal then pfft3d, waits for the result and
      prints 'Done'.

  Review notes (assumptions to confirm, nothing below changes the diff)
    * The (2, 4) reshape needs exactly 8 devices. job_pfft.sh provides
      2 nodes x 4 tasks, but the usage comment at the top of test_pfft.py
      says "srun -n 4"; presumably that comment should ask for 8 tasks.
    * pnormal passes one un-mapped key (in_axes=[...]), so presumably every
      (x, y) position draws the same vector; the commented-out
      "keys = jax.random.split(...)" line suggests per-shard keys were
      planned. Confirm whether this matters for the example.
    * 'key_x' and 'key_y' in pnormal's axis_resources do not appear in its
      in_axes, out_axes or axis_sizes; they look like leftovers.
    * pifft3d, PartitionSpec and pjit are defined/imported but never used
      by the driver.
    * The From: header carries no e-mail address; "git am" may reject the
      mail for that reason, while "git apply" does not read that header.

 dev/job_pfft.sh  | 14 ++++++++++
 dev/test_pfft.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+)
 create mode 100644 dev/job_pfft.sh
 create mode 100644 dev/test_pfft.py

diff --git a/dev/job_pfft.sh b/dev/job_pfft.sh
new file mode 100644
index 0000000..a0b73c0
--- /dev/null
+++ b/dev/job_pfft.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+#SBATCH -A m1727
+#SBATCH -C gpu
+#SBATCH -q debug
+#SBATCH -t 0:05:00
+#SBATCH -N 2
+#SBATCH --ntasks-per-node=4
+#SBATCH -c 32
+#SBATCH --gpus-per-task=1
+#SBATCH --gpu-bind=none
+
+module load python cudnn/8.2.0 nccl/2.11.4 cudatoolkit
+export SLURM_CPU_BIND="cores"
+srun python test_pfft.py
diff --git a/dev/test_pfft.py b/dev/test_pfft.py
new file mode 100644
index 0000000..873c238
--- /dev/null
+++ b/dev/test_pfft.py
@@ -0,0 +1,73 @@
+# Can be executed with:
+# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
+import jax
+import jax.numpy as jnp
+import numpy as np
+import jax.lax as lax
+from jax.experimental.maps import xmap
+from jax.experimental.maps import Mesh
+from jax.experimental.pjit import PartitionSpec, pjit
+from functools import partial
+
+jax.distributed.initialize()
+
+cube_size = 2048
+
+@partial(xmap,
+         in_axes=[...],
+         out_axes=['x','y', ...],
+         axis_sizes={'x':cube_size, 'y':cube_size},
+         axis_resources={'x': 'nx', 'y':'ny',
+                         'key_x':'nx', 'key_y':'ny'})
+def pnormal(key):
+    return jax.random.normal(key, shape=[cube_size])
+
+@partial(xmap,
+         in_axes={0:'x', 1:'y'},
+         out_axes=['x','y', ...],
+         axis_resources={'x': 'nx', 'y': 'ny'})
+@jax.jit
+def pfft3d(mesh):
+    # [x, y, z]
+    mesh = jnp.fft.fft(mesh) # Transform on z
+    mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
+    mesh = jnp.fft.fft(mesh) # Transform on x
+    mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
+    mesh = jnp.fft.fft(mesh) # Transform on y
+    # [z, x, y]
+    return mesh
+
+@partial(xmap,
+         in_axes={0:'x', 1:'y'},
+         out_axes=['x','y', ...],
+         axis_resources={'x': 'nx', 'y': 'ny'})
+@jax.jit
+def pifft3d(mesh):
+    # [z, x, y]
+    mesh = jnp.fft.ifft(mesh) # Transform on y
+    mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
+    mesh = jnp.fft.ifft(mesh) # Transform on x
+    mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
+    mesh = jnp.fft.ifft(mesh) # Transform on z
+    # [x, y, z]
+    return mesh
+
+key = jax.random.PRNGKey(42)
+# keys = jax.random.split(key, 4).reshape((2,2,2))
+
+# We reshape all our devices to the mesh shape we want
+devices = np.array(jax.devices()).reshape((2, 4))
+
+with Mesh(devices, ('nx', 'ny')):
+    mesh = pnormal(key)
+    kmesh = pfft3d(mesh)
+    kmesh.block_until_ready()
+
+# jax.profiler.start_trace("tensorboard")
+# with Mesh(devices, ('nx', 'ny')):
+#     mesh = pnormal(key)
+#     kmesh = pfft3d(mesh)
+#     kmesh.block_until_ready()
+# jax.profiler.stop_trace()
+
+print('Done')
\ No newline at end of file