mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 16:41:11 +00:00
Applying formatting
This commit is contained in:
parent
5f463450d1
commit
a2811c0606
15 changed files with 566 additions and 446 deletions
|
@ -1,48 +1,53 @@
|
|||
# Start this script with:
|
||||
# mpirun -np 4 python test_script.py
|
||||
import os
|
||||
|
||||
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
|
||||
import matplotlib.pylab as plt
|
||||
import jax
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
import jax.lax as lax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pylab as plt
|
||||
import numpy as np
|
||||
import tensorflow_probability as tfp
|
||||
from jax.experimental.maps import mesh, xmap
|
||||
from jax.experimental.pjit import PartitionSpec, pjit
|
||||
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
|
||||
|
||||
tfp = tfp.substrates.jax
|
||||
tfd = tfp.distributions
|
||||
|
||||
|
||||
def cic_paint(mesh, positions):
|
||||
""" Paints positions onto mesh
|
||||
""" Paints positions onto mesh
|
||||
mesh: [nx, ny, nz]
|
||||
positions: [npart, 3]
|
||||
"""
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0],
|
||||
[0., 0, 1], [1., 1, 0], [1., 0, 1],
|
||||
[0., 1, 1], [1., 1, 1]]])
|
||||
positions = jnp.expand_dims(positions, 1)
|
||||
floor = jnp.floor(positions)
|
||||
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
|
||||
[1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
|
||||
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
neighboor_coords = floor + connection
|
||||
kernel = 1. - jnp.abs(positions - neighboor_coords)
|
||||
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1,
|
||||
2))
|
||||
mesh = lax.scatter_add(
|
||||
mesh,
|
||||
neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
|
||||
kernel.reshape([-1, 8]), dnums)
|
||||
return mesh
|
||||
|
||||
dnums = jax.lax.ScatterDimensionNumbers(
|
||||
update_window_dims=(),
|
||||
inserted_window_dims=(0, 1, 2),
|
||||
scatter_dims_to_operand_dims=(0, 1, 2))
|
||||
mesh = lax.scatter_add(mesh,
|
||||
neighboor_coords.reshape([-1,8,3]).astype('int32'),
|
||||
kernel.reshape([-1,8]),
|
||||
dnums)
|
||||
return mesh
|
||||
|
||||
# And let's draw some points from some 3D distribution
|
||||
dist = tfd.MultivariateNormalDiag(loc=[16.,16.,16.], scale_identity_multiplier=3.)
|
||||
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
|
||||
scale_identity_multiplier=3.)
|
||||
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))
|
||||
|
||||
f = pjit(lambda x: cic_paint(x, pos),
|
||||
in_axis_resources=PartitionSpec('x', 'y', 'z'),
|
||||
in_axis_resources=PartitionSpec('x', 'y', 'z'),
|
||||
out_axis_resources=None)
|
||||
|
||||
devices = np.array(jax.devices()).reshape((2, 2, 1))
|
||||
|
@ -51,13 +56,13 @@ devices = np.array(jax.devices()).reshape((2, 2, 1))
|
|||
m = jnp.zeros([32, 32, 32])
|
||||
|
||||
with mesh(devices, ('x', 'y', 'z')):
|
||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
||||
m = pjit(lambda x: x,
|
||||
in_axis_resources=None,
|
||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
||||
# Shard the mesh, I'm not sure this is absolutely necessary
|
||||
m = pjit(lambda x: x,
|
||||
in_axis_resources=None,
|
||||
out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)
|
||||
|
||||
# Apply the sharded CiC function
|
||||
res = f(m)
|
||||
# Apply the sharded CiC function
|
||||
res = f(m)
|
||||
|
||||
plt.imshow(res.sum(axis=2))
|
||||
plt.show()
|
||||
plt.show()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue