JaxPM/dev/test_script.py
2024-07-09 14:54:34 -04:00

68 lines
2.1 KiB
Python

# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
    """Paint particle positions onto a 3D mesh with a trilinear kernel.

    Each particle contributes to the 8 grid cells at the corners of the
    unit cube that contains it. The contribution to a corner is the product
    over the three axes of ``1 - |position - corner|``.

    Args:
        mesh: array of shape [nx, ny, nz] that the weights are added to.
        positions: array of shape [npart, 3]. The coordinates are used
            directly as mesh indices after flooring, so they are expressed
            in units of grid cells.

    Returns:
        An array shaped like ``mesh`` holding ``mesh`` plus the painted
        weights.

    Note:
        No periodic wrapping is applied to the corner indices here; corners
        that fall outside the mesh are left to the default out-of-bounds
        handling of ``lax.scatter_add``.
    """
    # [npart, 3] -> [npart, 1, 3] so that the 8 corner offsets broadcast.
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)

    # Offsets from the lower corner to each of the 8 cube corners: [1, 8, 3].
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])

    # Corner coordinates and their per-axis linear weights: [npart, 8, 3].
    neighbor_coords = floor + connection
    kernel = 1. - jnp.abs(positions - neighbor_coords)
    # Combine the per-axis weights into one weight per corner: [npart, 8].
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]

    # Every update is a scalar (no window dims) and the trailing index
    # vector of length 3 addresses the three mesh axes in order.
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1, 2),
        scatter_dims_to_operand_dims=(0, 1, 2))

    # scatter_add accumulates, so particles sharing a cell sum their weights.
    mesh = lax.scatter_add(
        mesh,
        neighbor_coords.reshape([-1, 8, 3]).astype('int32'),
        kernel.reshape([-1, 8]),
        dnums)
    return mesh
# Draw particle positions from an isotropic 3D Gaussian centred on the
# 32^3 mesh allocated below (mean 16 on every axis, scale 3).
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
                                  scale_identity_multiplier=3.)
# A sample shape is an integer count, so pass 10000 rather than the float 1e4.
pos = dist.sample(10000, seed=jax.random.PRNGKey(0))

# Sharded painting function: the mesh argument is partitioned over the
# ('x', 'y', 'z') device axes, while `pos` is captured by the lambda instead
# of being passed as a pjit argument.
# NOTE(review): out_axis_resources=None presumably yields an unpartitioned
# (replicated) result -- confirm against the pjit documentation.
f = pjit(lambda x: cic_paint(x, pos),
         in_axis_resources=PartitionSpec('x', 'y', 'z'),
         out_axis_resources=None)

# Arrange the available devices on a 2 x 2 x 1 grid; the reshape requires
# exactly 4 devices.
devices = np.array(jax.devices()).reshape((2, 2, 1))

# Allocate the empty mesh that the particles will be painted onto.
m = jnp.zeros([32, 32, 32])

with mesh(devices, ('x', 'y', 'z')):
    # Shard the mesh, I'm not sure this is absolutely necessary
    m = pjit(lambda x: x,
             in_axis_resources=None,
             out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)

    # Apply the sharded CiC function
    res = f(m)

# Show the painted density summed along the last mesh axis.
plt.imshow(res.sum(axis=2))
plt.show()