From 6ab26ea1ec0ac151b37402ae14879085a718acce Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Tue, 21 Jan 2025 10:15:11 +0100 Subject: [PATCH] remove extra text --- tests/test_sharded_array.py | 149 ------------------------------------ 1 file changed, 149 deletions(-) delete mode 100644 tests/test_sharded_array.py diff --git a/tests/test_sharded_array.py b/tests/test_sharded_array.py deleted file mode 100644 index 47bc64a..0000000 --- a/tests/test_sharded_array.py +++ /dev/null @@ -1,149 +0,0 @@ -import os - -#os.environ["JAX_PLATFORM_NAME"] = "cpu" -#os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" - -os.environ["EQX_ON_ERROR"] = "nan" -from functools import partial - -import jax -import jax.numpy as jnp -import jax_cosmo as jc -from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, - diffeqsolve) -from jax.debug import visualize_array_sharding -from jax.experimental.mesh_utils import create_device_mesh -from jax.experimental.multihost_utils import process_allgather -from jax.sharding import Mesh, NamedSharding -from jax.sharding import PartitionSpec as P - -from jaxpm.distributed import uniform_particles -from jaxpm.kernels import interpolate_power_spectrum -from jaxpm.painting import cic_paint, cic_paint_dx, cic_read, cic_read_dx -from jaxpm.pm import linear_field, lpt, make_diffrax_ode - -#assert jax.device_count() >= 8, "This notebook requires a TPU or GPU runtime with 8 devices" - -all_gather = partial(process_allgather, tiled=False) - -pdims = (2, 4) -#devices = create_device_mesh(pdims) -#mesh = Mesh(devices, axis_names=('x', 'y')) -#sharding = NamedSharding(mesh, P('x', 'y')) -sharding = None - -from typing import NamedTuple - -from jaxdecomp import ShardedArray - -mesh_shape = 64 -box_size = 64. -halo_size = 2 -snapshots = (0.5, 1.0) - - -class Params(NamedTuple): - omega_c: float - sigma8: float - initial_conditions: jnp.ndarray - - -mesh_shape = (mesh_shape, ) * 3 -box_size = (box_size, ) * 3 -omega_c = 0.25 -sigma8 = 0.8 -# Create a small function to generate the matter power spectrum -k = jnp.logspace(-4, 1, 128) -pk = jc.power.linear_matter_power(jc.Planck15(Omega_c=omega_c, sigma8=sigma8), - k) -pk_fn = lambda x: interpolate_power_spectrum(x, k, pk, sharding) - -initial_conditions = linear_field(mesh_shape, - box_size, - pk_fn, - seed=jax.random.PRNGKey(0), - sharding=sharding) - -#initial_conditions = ShardedArray(initial_conditions, sharding) - -params = Params(omega_c, sigma8, initial_conditions) - - -@partial(jax.jit, static_argnums=(1, 2, 3, 4)) -def forward_model(params, mesh_shape, box_size, halo_size, snapshots): - - # Create initial conditions - cosmo = jc.Planck15(Omega_c=params.omega_c, sigma8=params.sigma8) - particles = uniform_particles(mesh_shape, sharding) - ic_structure = jax.tree.structure(params.initial_conditions) - particles = jax.tree.unflatten(ic_structure, jax.tree.leaves(particles)) - # Initial displacement - dx, p, f = lpt(cosmo, - params.initial_conditions, - particles, - a=0.1, - order=2, - halo_size=halo_size, - sharding=sharding) - - # Evolve the simulation forward - ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, - paint_absolute_pos=True, - halo_size=halo_size, - sharding=sharding)) - solver = LeapfrogMidpoint() - - y0 = jax.tree.map( - lambda particles, dx, p: jnp.stack([particles + dx, p], axis=0), - particles, dx, p) - print(f"y0 structure: {jax.tree.structure(y0)}") - - stepsize_controller = ConstantStepSize() - res = diffeqsolve(ode_fn, - solver, - t0=0.1, - t1=1., - dt0=0.01, - y0=y0, - args=cosmo, - saveat=SaveAt(ts=snapshots), - stepsize_controller=stepsize_controller) - ode_solutions = [sol[0] for sol in res.ys] - - ode_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), - ode_solutions[-1]) - return particles + dx, ode_field - - ode_field = cic_paint_dx(ode_solutions[-1]) - return dx, ode_field - - -lpt_particles, ode_field = forward_model(params, mesh_shape, box_size, - halo_size, snapshots) - -import matplotlib.pyplot as plt - -lpt_field = cic_paint(jnp.zeros(mesh_shape, jnp.float32), lpt_particles) -#lpt_field = cic_paint_dx(lpt_particles) - -plt.figure(figsize=(12, 6)) -plt.subplot(121) -plt.imshow(lpt_field.sum(axis=0), cmap='magma') -plt.colorbar() -plt.title('LPT field') -plt.subplot(122) -plt.imshow(ode_field.sum(axis=0), cmap='magma') -plt.colorbar() -plt.title('ODE field') -plt.show() -plt.close() - -#particles = jax.random.uniform(jax.random.PRNGKey(0), (4 , 4 ,4 , 3), minval=0.1, maxval=0.9) -#field = jax.random.uniform(jax.random.PRNGKey(0), (4, 4, 4)) -# -#partiles = ShardedArray(particles, sharding) -#field = ShardedArray(field, sharding) -# -# -#cic_read_dx(field , particles )