diff --git a/jaxpm/distributed.py b/jaxpm/distributed.py index 4fdb764..dbbb8fd 100644 --- a/jaxpm/distributed.py +++ b/jaxpm/distributed.py @@ -40,7 +40,7 @@ def ifft3d(x): def get_halo_size(halo_size, sharding): gpu_mesh = sharding.mesh if sharding is not None else None if gpu_mesh is None or gpu_mesh.empty: - zero_ext = (0, 0, 0) + zero_ext = (0, 0) zero_tuple = (0, 0) return (zero_tuple, zero_tuple, zero_tuple), zero_ext else: diff --git a/jaxpm/painting.py b/jaxpm/painting.py index 838fe38..fd52106 100644 --- a/jaxpm/painting.py +++ b/jaxpm/painting.py @@ -5,8 +5,8 @@ import jax.lax as lax import jax.numpy as jnp from jax.sharding import PartitionSpec as P -from jaxpm.distributed import (autoshmap, get_halo_size, halo_exchange, - slice_pad, slice_unpad, fft3d, ifft3d) +from jaxpm.distributed import (autoshmap, fft3d, get_halo_size, halo_exchange, + ifft3d, slice_pad, slice_unpad) from jaxpm.kernels import cic_compensation, fftk from jaxpm.painting_utils import gather, scatter diff --git a/notebooks/03-MultiHost_PM.py b/notebooks/03-MultiHost_PM.py index 06dccb7..9cda78b 100644 --- a/notebooks/03-MultiHost_PM.py +++ b/notebooks/03-MultiHost_PM.py @@ -7,21 +7,21 @@ jax.distributed.initialize() rank = jax.process_index() size = jax.process_count() +from functools import partial + import jax.numpy as jnp import jax_cosmo as jc - -from jaxpm.kernels import interpolate_power_spectrum -from jaxpm.painting import cic_paint_dx -from jaxpm.pm import linear_field, lpt, make_ode_fn - +import numpy as np +from diffrax import (ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, + diffeqsolve) from jax.experimental.mesh_utils import create_device_mesh from jax.experimental.multihost_utils import process_allgather from jax.sharding import Mesh, NamedSharding from jax.sharding import PartitionSpec as P -from functools import partial -import numpy as np -from diffrax import ConstantStepSize, LeapfrogMidpoint, ODETerm, SaveAt, diffeqsolve +from jaxpm.kernels import interpolate_power_spectrum +from jaxpm.painting import cic_paint_dx +from jaxpm.pm import linear_field, lpt, make_ode_fn all_gather = partial(process_allgather, tiled=True) @@ -84,7 +84,8 @@ def run_simulation(omega_c, sigma8): return initial_conditions, dx, res.ys, res.stats -initial_conditions , lpt_displacements , ode_solutions , solver_stats = run_simulation(0.25, 0.8) +initial_conditions, lpt_displacements, ode_solutions, solver_stats = run_simulation( + 0.25, 0.8) print(f"[{rank}] Simulation completed") print(f"[{rank}] Solver stats: {solver_stats}") @@ -95,10 +96,9 @@ ode_solutions = [all_gather(sol) for sol in ode_solutions] if rank == 0: np.savez("multihost_pm.npz", - initial_conditions=initial_conditions, - lpt_displacements=lpt_displacements, - ode_solutions=ode_solutions, - solver_stats=solver_stats) + initial_conditions=initial_conditions, + lpt_displacements=lpt_displacements, + ode_solutions=ode_solutions, + solver_stats=solver_stats) print(f"[{rank}] Simulation results saved") - \ No newline at end of file diff --git a/notebooks/visualize.py b/notebooks/visualize.py index 0db2273..0a3884b 100644 --- a/notebooks/visualize.py +++ b/notebooks/visualize.py @@ -1,6 +1,7 @@ -import numpy as np import jax.numpy as jnp import matplotlib.pyplot as plt +import numpy as np + def plot_fields(fields_dict, sum_over=None): """ @@ -23,8 +24,10 @@ def plot_fields(fields_dict, sum_over=None): slicing = tuple(slicing) # Sum projection over the specified axis and plot - axes[row, proj_axis].imshow(field[slicing].sum(axis=proj_axis) + 1, - cmap='magma', extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]]) + axes[row, proj_axis].imshow( + field[slicing].sum(axis=proj_axis) + 1, + cmap='magma', + extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]]) axes[row, proj_axis].set_xlabel('Mpc/h') axes[row, proj_axis].set_ylabel('Mpc/h') axes[row, proj_axis].set_title(title) @@ -32,7 +35,8 @@ def plot_fields(fields_dict, sum_over=None): # Plot each field across the three axes for i, (name, field) in enumerate(fields_dict.items()): for proj_axis in range(3): - plot_subplots(proj_axis, field, i, f"{name} projection {proj_axis}") + plot_subplots(proj_axis, field, i, + f"{name} projection {proj_axis}") plt.tight_layout() plt.show() diff --git a/setup.py b/setup.py index 43b7315..b9b6b29 100644 --- a/setup.py +++ b/setup.py @@ -7,5 +7,5 @@ setup( author='JaxPM developers', description='A dead simple FastPM implementation in JAX', packages=find_packages(), - install_requires=['jax', 'jax_cosmo','jaxdecomp'], + install_requires=['jax', 'jax_cosmo', 'jaxdecomp'], )