From b43cb373a0b37b474d14c1a97a8e39e42f8c2def Mon Sep 17 00:00:00 2001 From: Wassim Kabalan Date: Fri, 28 Feb 2025 09:56:00 +0100 Subject: [PATCH] format --- notebooks/05-MultiHost_PM.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/notebooks/05-MultiHost_PM.py b/notebooks/05-MultiHost_PM.py index 19863ad..c41d1cf 100644 --- a/notebooks/05-MultiHost_PM.py +++ b/notebooks/05-MultiHost_PM.py @@ -18,7 +18,8 @@ import numpy as np from diffrax import (ConstantStepSize, Dopri5, LeapfrogMidpoint, ODETerm, PIDController, SaveAt, diffeqsolve) from jax.experimental.multihost_utils import process_allgather -from jax.sharding import PartitionSpec as P, NamedSharding +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P from jaxpm.kernels import interpolate_power_spectrum from jaxpm.painting import cic_paint_dx @@ -104,7 +105,10 @@ def run_simulation(omega_c, sigma8, mesh_shape, box_size, halo_size, sharding=sharding) ode_fn = ODETerm( - make_diffrax_ode(mesh_shape, paint_absolute_pos=False,sharding=sharding , halo_size=halo_size)) + make_diffrax_ode(mesh_shape, + paint_absolute_pos=False, + sharding=sharding, + halo_size=halo_size)) # Choose solver solver = LeapfrogMidpoint() if solver_choice == "leapfrog" else Dopri5()