Mirror of https://github.com/DifferentiableUniverseInitiative/JaxPM.git, synced 2025-04-04 11:10:53 +00:00
remove outdated tests
This commit is contained in:
parent 158478cb4a
commit f91aa931dc
3 changed files with 0 additions and 233 deletions
@@ -1,69 +0,0 @@
import argparse

import jax
import numpy as np

# Setting up distributed jax
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()

import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental import mesh_utils
from jax.sharding import Mesh

from jaxpm.painting import cic_paint
from jaxpm.pm import linear_field, lpt

mesh_shape = [256, 256, 256]
box_size = [256., 256., 256.]
snapshots = jnp.linspace(0.1, 1., 2)


@jax.jit
def run_simulation(omega_c, sigma8, seed):
    # Create a cosmology
    cosmo = jc.Planck15(Omega_c=omega_c, sigma8=sigma8)

    # Create a small function to generate the matter power spectrum
    k = jnp.logspace(-4, 1, 128)
    pk = jc.power.linear_matter_power(
        jc.Planck15(Omega_c=omega_c, sigma8=sigma8), k)
    pk_fn = lambda x: jc.scipy.interpolate.interp(x.reshape([-1]), k, pk
                                                  ).reshape(x.shape)

    # Create initial conditions
    initial_conditions = linear_field(mesh_shape, box_size, pk_fn, seed=seed)

    # Initialize particle displacements
    dx, p, f = lpt(cosmo, initial_conditions, 1.0)

    field = cic_paint(jnp.zeros_like(initial_conditions), dx)
    return field


def main(args):
    # Setting up distributed random numbers
    master_key = jax.random.PRNGKey(42)
    key = jax.random.split(master_key, size)[rank]

    # Create computing mesh and sharding information
    devices = mesh_utils.create_device_mesh((2, 2))
    mesh = Mesh(devices.T, axis_names=('x', 'y'))

    # Run the simulation on the compute mesh
    with mesh:
        field = run_simulation(0.32, 0.8, key)

    print('done')
    np.save(f'field_{rank}.npy', field.addressable_data(0))

    # Closing distributed jax
    jax.distributed.shutdown()


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Distributed LPT N-body simulation.")
    args = parser.parse_args()
    main(args)

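Note: the script above assumes a multi-process launch (for example one process per GPU under a launcher such as srun or mpirun) so that jax.distributed.initialize() can discover its peers, and a total of 4 devices for the (2, 2) device mesh. A minimal sketch of just that setup, under those assumptions:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

jax.distributed.initialize()                     # one call per process
assert jax.device_count() == 4                   # required for the 2 x 2 mesh below
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices.T, axis_names=('x', 'y'))
print(jax.process_index(), jax.process_count(), mesh)
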
@@ -1,96 +0,0 @@
# Can be executed with:
# srun -n 4 -c 32 --gpus-per-task 1 --gpu-bind=none python test_pfft.py
from functools import partial

import jax
import jax.lax as lax
import jax.numpy as jnp
import numpy as np
from jax.experimental.maps import Mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit

jax.distributed.initialize()

cube_size = 2048


@partial(xmap,
         in_axes=[...],
         out_axes=['x', 'y', ...],
         axis_sizes={
             'x': cube_size,
             'y': cube_size
         },
         axis_resources={
             'x': 'nx',
             'y': 'ny',
             'key_x': 'nx',
             'key_y': 'ny'
         })
def pnormal(key):
    return jax.random.normal(key, shape=[cube_size])


@partial(xmap,
         in_axes={
             0: 'x',
             1: 'y'
         },
         out_axes=['x', 'y', ...],
         axis_resources={
             'x': 'nx',
             'y': 'ny'
         })
@jax.jit
def pfft3d(mesh):
    # [x, y, z]
    mesh = jnp.fft.fft(mesh)  # Transform on z
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now x is exposed, [z, y, x]
    mesh = jnp.fft.fft(mesh)  # Transform on x
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now y is exposed, [z, x, y]
    mesh = jnp.fft.fft(mesh)  # Transform on y
    # [z, x, y]
    return mesh


@partial(xmap,
         in_axes={
             0: 'x',
             1: 'y'
         },
         out_axes=['x', 'y', ...],
         axis_resources={
             'x': 'nx',
             'y': 'ny'
         })
@jax.jit
def pifft3d(mesh):
    # [z, x, y]
    mesh = jnp.fft.ifft(mesh)  # Transform on y
    mesh = lax.all_to_all(mesh, 'y', 0, 0)  # Now x is exposed, [z, y, x]
    mesh = jnp.fft.ifft(mesh)  # Transform on x
    mesh = lax.all_to_all(mesh, 'x', 0, 0)  # Now z is exposed, [x, y, z]
    mesh = jnp.fft.ifft(mesh)  # Transform on z
    # [x, y, z]
    return mesh


key = jax.random.PRNGKey(42)
# keys = jax.random.split(key, 4).reshape((2,2,2))

# We reshape all our devices to the mesh shape we want
devices = np.array(jax.devices()).reshape((2, 4))

with Mesh(devices, ('nx', 'ny')):
    mesh = pnormal(key)
    kmesh = pfft3d(mesh)
    kmesh.block_until_ready()

# jax.profiler.start_trace("tensorboard")
# with Mesh(devices, ('nx', 'ny')):
#     mesh = pnormal(key)
#     kmesh = pfft3d(mesh)
#     kmesh.block_until_ready()
# jax.profiler.stop_trace()

print('Done')

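The pfft3d/pifft3d pair above relies on the 3D FFT being separable: transforming one axis at a time, with lax.all_to_all re-exposing the next distributed axis between steps, is equivalent to a full 3D transform up to the axis permutation noted in the comments. A single-device sanity check of that decomposition, kept here as an illustration only:

import jax.numpy as jnp
import numpy as np

x = jnp.asarray(np.random.rand(8, 8, 8))
step = jnp.fft.fft(x, axis=2)     # transform on z
step = jnp.fft.fft(step, axis=0)  # transform on x
step = jnp.fft.fft(step, axis=1)  # transform on y
assert jnp.allclose(step, jnp.fft.fftn(x), atol=1e-3)  # matches the full 3D FFT
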
@@ -1,68 +0,0 @@
# Start this script with:
# mpirun -np 4 python test_script.py
import os

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pylab as plt
import numpy as np
import tensorflow_probability as tfp
from jax.experimental.maps import mesh, xmap
from jax.experimental.pjit import PartitionSpec, pjit

tfp = tfp.substrates.jax
tfd = tfp.distributions


def cic_paint(mesh, positions):
    """ Paints positions onto mesh
    mesh: [nx, ny, nz]
    positions: [npart, 3]
    """
    positions = jnp.expand_dims(positions, 1)
    floor = jnp.floor(positions)
    connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                             [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])

    neighboor_coords = floor + connection
    kernel = 1. - jnp.abs(positions - neighboor_coords)
    kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]

    dnums = jax.lax.ScatterDimensionNumbers(update_window_dims=(),
                                            inserted_window_dims=(0, 1, 2),
                                            scatter_dims_to_operand_dims=(0, 1,
                                                                          2))
    mesh = lax.scatter_add(
        mesh,
        neighboor_coords.reshape([-1, 8, 3]).astype('int32'),
        kernel.reshape([-1, 8]), dnums)
    return mesh


# And let's draw some points from some 3D distribution
dist = tfd.MultivariateNormalDiag(loc=[16., 16., 16.],
                                  scale_identity_multiplier=3.)
pos = dist.sample(1e4, seed=jax.random.PRNGKey(0))

f = pjit(lambda x: cic_paint(x, pos),
         in_axis_resources=PartitionSpec('x', 'y', 'z'),
         out_axis_resources=None)

devices = np.array(jax.devices()).reshape((2, 2, 1))

# Let's import the mesh
m = jnp.zeros([32, 32, 32])

with mesh(devices, ('x', 'y', 'z')):
    # Shard the mesh, I'm not sure this is absolutely necessary
    m = pjit(lambda x: x,
             in_axis_resources=None,
             out_axis_resources=PartitionSpec('x', 'y', 'z'))(m)

    # Apply the sharded CiC function
    res = f(m)

plt.imshow(res.sum(axis=2))
plt.show()

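The cic_paint above spreads each particle over its 8 neighbouring grid cells with trilinear (cloud-in-cell) weights, so the weights of any single particle sum to 1 and total mass is conserved by the scatter-add. A tiny standalone check of that weighting (the particle position is an arbitrary example value):

import jax.numpy as jnp

pos = jnp.array([[12.3, 5.7, 20.1]])        # one particle, shape [1, 3]
pos = jnp.expand_dims(pos, 1)               # [1, 1, 3]
floor = jnp.floor(pos)
connection = jnp.array([[[0, 0, 0], [1., 0, 0], [0., 1, 0], [0., 0, 1],
                         [1., 1, 0], [1., 0, 1], [0., 1, 1], [1., 1, 1]]])
kernel = 1. - jnp.abs(pos - (floor + connection))    # per-axis distance to the 8 cells
kernel = kernel[..., 0] * kernel[..., 1] * kernel[..., 2]
print(kernel.sum())                         # ~1.0
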