run precommit

This commit is contained in:
Wassim KABALAN 2024-07-08 00:22:29 +02:00
parent 38e875a7df
commit 5b443e95c5
11 changed files with 481 additions and 380 deletions

View file

@ -1,48 +1,53 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import matplotlib.pylab as plt
import jax
import numpy as np
import jax.numpy as jnp
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfp = tfp.substrates.jax
tfd = tfp.distributions
def cic_paint(mesh, positions):
""" Paints positions onto mesh
""" Paints positions onto mesh
mesh: [nx, ny, nz]
positions: [npart, 3]
"""
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
[0., 0, 1], [1., 1, 0], [1., 0, 1],
[0., 1, 1], [1., 1, 1]]])
positions = jnp.expand_dims(positions, 1)
floor = jnp.floor(positions)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
neighboor_coords = floor + connection
kernel = 1. - jnp.abs(positions - neighboor_coords)
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1,
2))
mesh = lax.scatter_add(
mesh,
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
kernel.reshape([-1, 8]), dnums)
return mesh
dnums = jax.lax.ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0, 1, 2),
scatter_dims_to_operand_dims=(0, 1, 2))
mesh = lax.scatter_add(mesh,
neighboor_coords.reshape([-1,8,3]).astype('int32'),
kernel.reshape([-1,8]),
dnums)
return mesh
# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
f = pjit(lambda x: cic_paint(x, pos),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
in_axis_resources=PartitionSpec('x', 'y', 'z'),
out_axis_resources=None)
devices = np.array(jax.devices()).reshape((2, 2, 1))
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
m = jnp.zeros([32, 32, 32])
with mesh(devices, ('x', 'y', 'z')):
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Shard the mesh, I'm not sure this is absolutely necessary
m = pjit(lambda x: x,
in_axis_resources=None,
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
# Apply the sharded CiC function
res = f(m)
# Apply the sharded CiC function
res = f(m)
plt.imshow(res.sum(axis=2))
plt.show()
plt.show()