update notebooks

This commit is contained in:
Wassim KABALAN 2024-10-27 03:49:07 +01:00
parent b4fdb74660
commit c93894f561
5 changed files with 341 additions and 174 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -30,11 +30,10 @@ devices = create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('x', 'y'))
sharding = NamedSharding(mesh, P('x', 'y'))
mesh_shape = [2024, 1024, 1024]
box_size = [1024., 1024., 1024.]
halo_size = 512
snapshots = jnp.linspace(0.1, 1., 2)
mesh_shape = [512, 512, 512]
box_size = [500., 500., 1000.]
halo_size = 64
snapshots = jnp.linspace(0.1,1.,2)
@jax.jit
def run_simulation(omega_c, sigma8):
@ -59,8 +58,7 @@ def run_simulation(omega_c, sigma8):
# Initial displacement
dx, p, _ = lpt(cosmo,
initial_conditions,
particles,
0.1,
a=0.1,
halo_size=halo_size,
sharding=sharding)
@ -90,15 +88,16 @@ print(f"[{rank}] Simulation completed")
print(f"[{rank}] Solver stats: {solver_stats}")
# Gather the results
initial_conditions = all_gather(initial_conditions)
lpt_displacements = all_gather(lpt_displacements)
ode_solutions = [all_gather(sol) for sol in ode_solutions]
pm_dict = {"initial_conditions": all_gather(initial_conditions),
"lpt_displacements": all_gather(lpt_displacements),
"solver_stats": solver_stats}
for i in range(len(ode_solutions)):
sol = ode_solutions[i]
pm_dict[f"ode_solution_{i}"] = all_gather(sol)
if rank == 0:
np.savez("multihost_pm.npz",
initial_conditions=initial_conditions,
lpt_displacements=lpt_displacements,
ode_solutions=ode_solutions,
solver_stats=solver_stats)
np.savez("multihost_pm.npz", **pm_dict)
print(f"[{rank}] Simulation results saved")

View file

@ -40,3 +40,36 @@ def plot_fields(fields_dict, sum_over=None):
plt.tight_layout()
plt.show()
def plot_fields_single_projection(fields_dict, sum_over=None):
"""
Plots a single projection (along axis 0) of 3D fields in one row,
summing over the first `sum_over` elements along the 0-axis.
Args:
- fields_dict: dictionary where keys are field names and values are 3D arrays
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
"""
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
nb_cols = len(fields_dict)
fig, axes = plt.subplots(1, nb_cols, figsize=(5 * nb_cols, 5))
for i, (name, field) in enumerate(fields_dict.items()):
# Define the slice for the 0-axis projection
slicing = [slice(None)] * field.ndim
slicing[0] = slice(None, sum_over)
slicing = tuple(slicing)
# Sum projection over axis 0 and plot
axes[i].imshow(
field[slicing].sum(axis=0) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]]
)
axes[i].set_xlabel('Mpc/h')
axes[i].set_ylabel('Mpc/h')
axes[i].set_title(f"{name} projection 0")
plt.tight_layout()
plt.show()