apply formating

This commit is contained in:
Wassim KABALAN 2024-10-27 03:50:34 +01:00
parent c93894f561
commit 19011d0712
5 changed files with 22 additions and 15 deletions

View file

@ -33,7 +33,8 @@ sharding = NamedSharding(mesh, P('x', 'y'))
mesh_shape = [512, 512, 512]
box_size = [500., 500., 1000.]
halo_size = 64
snapshots = jnp.linspace(0.1,1.,2)
snapshots = jnp.linspace(0.1, 1., 2)
@jax.jit
def run_simulation(omega_c, sigma8):
@ -89,9 +90,11 @@ print(f"[{rank}] Solver stats: {solver_stats}")
# Gather the results
pm_dict = {"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats}
pm_dict = {
"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats
}
for i in range(len(ode_solutions)):
sol = ode_solutions[i]

View file

@ -62,11 +62,9 @@ def plot_fields_single_projection(fields_dict, sum_over=None):
slicing = tuple(slicing)
# Sum projection over axis 0 and plot
axes[i].imshow(
field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]]
)
axes[i].imshow(field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]])
axes[i].set_xlabel('Mpc/h')
axes[i].set_ylabel('Mpc/h')
axes[i].set_title(f"{name} projection 0")