mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
update notebooks
This commit is contained in:
parent
b4fdb74660
commit
c93894f561
5 changed files with 341 additions and 174 deletions
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
@ -30,11 +30,10 @@ devices = create_device_mesh(pdims)
|
|||
mesh = Mesh(devices, axis_names=('x', 'y'))
|
||||
sharding = NamedSharding(mesh, P('x', 'y'))
|
||||
|
||||
mesh_shape = [2024, 1024, 1024]
|
||||
box_size = [1024., 1024., 1024.]
|
||||
halo_size = 512
|
||||
snapshots = jnp.linspace(0.1, 1., 2)
|
||||
|
||||
mesh_shape = [512, 512, 512]
|
||||
box_size = [500., 500., 1000.]
|
||||
halo_size = 64
|
||||
snapshots = jnp.linspace(0.1,1.,2)
|
||||
|
||||
@jax.jit
|
||||
def run_simulation(omega_c, sigma8):
|
||||
|
@ -59,8 +58,7 @@ def run_simulation(omega_c, sigma8):
|
|||
# Initial displacement
|
||||
dx, p, _ = lpt(cosmo,
|
||||
initial_conditions,
|
||||
particles,
|
||||
0.1,
|
||||
a=0.1,
|
||||
halo_size=halo_size,
|
||||
sharding=sharding)
|
||||
|
||||
|
@ -90,15 +88,16 @@ print(f"[{rank}] Simulation completed")
|
|||
print(f"[{rank}] Solver stats: {solver_stats}")
|
||||
|
||||
# Gather the results
|
||||
initial_conditions = all_gather(initial_conditions)
|
||||
lpt_displacements = all_gather(lpt_displacements)
|
||||
ode_solutions = [all_gather(sol) for sol in ode_solutions]
|
||||
|
||||
pm_dict = {"initial_conditions": all_gather(initial_conditions),
|
||||
"lpt_displacements": all_gather(lpt_displacements),
|
||||
"solver_stats": solver_stats}
|
||||
|
||||
for i in range(len(ode_solutions)):
|
||||
sol = ode_solutions[i]
|
||||
pm_dict[f"ode_solution_{i}"] = all_gather(sol)
|
||||
|
||||
if rank == 0:
|
||||
np.savez("multihost_pm.npz",
|
||||
initial_conditions=initial_conditions,
|
||||
lpt_displacements=lpt_displacements,
|
||||
ode_solutions=ode_solutions,
|
||||
solver_stats=solver_stats)
|
||||
np.savez("multihost_pm.npz", **pm_dict)
|
||||
|
||||
print(f"[{rank}] Simulation results saved")
|
||||
|
|
|
@ -40,3 +40,36 @@ def plot_fields(fields_dict, sum_over=None):
|
|||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_fields_single_projection(fields_dict, sum_over=None):
|
||||
"""
|
||||
Plots a single projection (along axis 0) of 3D fields in one row,
|
||||
summing over the first `sum_over` elements along the 0-axis.
|
||||
|
||||
Args:
|
||||
- fields_dict: dictionary where keys are field names and values are 3D arrays
|
||||
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
|
||||
"""
|
||||
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||
nb_cols = len(fields_dict)
|
||||
fig, axes = plt.subplots(1, nb_cols, figsize=(5 * nb_cols, 5))
|
||||
|
||||
for i, (name, field) in enumerate(fields_dict.items()):
|
||||
# Define the slice for the 0-axis projection
|
||||
slicing = [slice(None)] * field.ndim
|
||||
slicing[0] = slice(None, sum_over)
|
||||
slicing = tuple(slicing)
|
||||
|
||||
# Sum projection over axis 0 and plot
|
||||
axes[i].imshow(
|
||||
field[slicing].sum(axis=0) + 1,
|
||||
cmap='magma',
|
||||
extent=[0, field.shape[1], 0, field.shape[2]]
|
||||
)
|
||||
axes[i].set_xlabel('Mpc/h')
|
||||
axes[i].set_ylabel('Mpc/h')
|
||||
axes[i].set_title(f"{name} projection 0")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
|
Loading…
Add table
Reference in a new issue