mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Applying formatting
This commit is contained in:
parent
5f463450d1
commit
a2811c0606
15 changed files with 566 additions and 446 deletions
|
@ -1,57 +1,80 @@
|
|||
# Can be executed with:
|
||||
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
|
||||
import jax
|
||||
from functools import partial
|
||||
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import jax.lax as lax
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental.maps import Mesh
|
||||
from jax.experimental.maps import Mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
from functools import partial
|
||||
|
||||
jax.distributed.initialize()
|
||||
|
||||
cube_size = 2048
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes=[...],
|
||||
out_axes=['x','y', ...],
|
||||
axis_sizes={'x':cube_size, 'y':cube_size},
|
||||
axis_resources={'x': 'nx', 'y':'ny',
|
||||
'key_x':'nx', 'key_y':'ny'})
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_sizes={
|
||||
'x': cube_size,
|
||||
'y': cube_size
|
||||
},
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny',
|
||||
'key_x': 'nx',
|
||||
'key_y': 'ny'
|
||||
})
|
||||
def pnormal(key):
|
||||
return jax.random.normal(key, shape=[cube_size])
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes={0:'x', 1:'y'},
|
||||
out_axes=['x','y', ...],
|
||||
axis_resources={'x': 'nx', 'y': 'ny'})
|
||||
in_axes={
|
||||
0: 'x',
|
||||
1: 'y'
|
||||
},
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny'
|
||||
})
|
||||
@jax.jit
|
||||
def pfft3d(mesh):
|
||||
# [x, y, z]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on z
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on y
|
||||
mesh = jnp.fft.fft(mesh) # Transform on z
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now y is exposed, [z,x,y]
|
||||
mesh = jnp.fft.fft(mesh) # Transform on y
|
||||
# [z, x, y]
|
||||
return mesh
|
||||
|
||||
|
||||
@partial(xmap,
|
||||
in_axes={0:'x', 1:'y'},
|
||||
out_axes=['x','y', ...],
|
||||
axis_resources={'x': 'nx', 'y': 'ny'})
|
||||
in_axes={
|
||||
0: 'x',
|
||||
1: 'y'
|
||||
},
|
||||
out_axes=['x', 'y', ...],
|
||||
axis_resources={
|
||||
'x': 'nx',
|
||||
'y': 'ny'
|
||||
})
|
||||
@jax.jit
|
||||
def pifft3d(mesh):
|
||||
# [z, x, y]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on y
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on z
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on y
|
||||
mesh = lax.all_to_all(mesh, 'y', 0, 0) # Now x is exposed, [z,y,x]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on x
|
||||
mesh = lax.all_to_all(mesh, 'x', 0, 0) # Now z is exposed, [x,y,z]
|
||||
mesh = jnp.fft.ifft(mesh) # Transform on z
|
||||
# [x, y, z]
|
||||
return mesh
|
||||
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
# keys = jax.random.split(key, 4).reshape((2,2,2))
|
||||
|
||||
|
@ -68,6 +91,6 @@ with Mesh(devices, ('nx', 'ny')):
|
|||
# mesh = pnormal(key)
|
||||
# kmesh = pfft3d(mesh)
|
||||
# kmesh.block_until_ready()
|
||||
# jax.profiler.stop_trace()
|
||||
# jax.profiler.stop_trace()
|
||||
|
||||
print('Done')
|
||||
print('Done')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue