mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-05-15 04:21:12 +00:00
apply formating
This commit is contained in:
parent
c93894f561
commit
19011d0712
5 changed files with 22 additions and 15 deletions
|
@ -33,7 +33,8 @@ sharding = NamedSharding(mesh, P('x', 'y'))
|
|||
mesh_shape = [512, 512, 512]
|
||||
box_size = [500., 500., 1000.]
|
||||
halo_size = 64
|
||||
snapshots = jnp.linspace(0.1,1.,2)
|
||||
snapshots = jnp.linspace(0.1, 1., 2)
|
||||
|
||||
|
||||
@jax.jit
|
||||
def run_simulation(omega_c, sigma8):
|
||||
|
@ -89,9 +90,11 @@ print(f"[{rank}] Solver stats: {solver_stats}")
|
|||
|
||||
# Gather the results
|
||||
|
||||
pm_dict = {"initial_conditions": all_gather(initial_conditions),
|
||||
"lpt_displacements": all_gather(lpt_displacements),
|
||||
"solver_stats": solver_stats}
|
||||
pm_dict = {
|
||||
"initial_conditions": all_gather(initial_conditions),
|
||||
"lpt_displacements": all_gather(lpt_displacements),
|
||||
"solver_stats": solver_stats
|
||||
}
|
||||
|
||||
for i in range(len(ode_solutions)):
|
||||
sol = ode_solutions[i]
|
||||
|
|
|
@ -62,11 +62,9 @@ def plot_fields_single_projection(fields_dict, sum_over=None):
|
|||
slicing = tuple(slicing)
|
||||
|
||||
# Sum projection over axis 0 and plot
|
||||
axes[i].imshow(
|
||||
field[slicing].sum(axis=0) + 1,
|
||||
cmap='magma',
|
||||
extent=[0, field.shape[1], 0, field.shape[2]]
|
||||
)
|
||||
axes[i].imshow(field[slicing].sum(axis=0) + 1,
|
||||
cmap='magma',
|
||||
extent=[0, field.shape[1], 0, field.shape[2]])
|
||||
axes[i].set_xlabel('Mpc/h')
|
||||
axes[i].set_ylabel('Mpc/h')
|
||||
axes[i].set_title(f"{name} projection 0")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue