Put plotting utils in package

This commit is contained in:
Wassim Kabalan 2024-12-05 18:22:34 +01:00
parent b32014b7ea
commit c1b276d224
2 changed files with 14 additions and 124 deletions

View file

@ -42,7 +42,12 @@ def plot_fields(fields_dict, sum_over=None):
plt.show()
def plot_fields_single_projection(fields_dict, sum_over=None, project_axis=0):
def plot_fields_single_projection(fields_dict,
sum_over=None,
project_axis=0,
vmin=None,
vmax=None,
colorbar=False):
"""
Plots a single projection (along axis 0) of 3D fields in a grid,
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
@ -70,12 +75,17 @@ def plot_fields_single_projection(fields_dict, sum_over=None, project_axis=0):
slicing = tuple(slicing)
# Sum projection over axis 0 and plot
axes[row, col].imshow(field[slicing].sum(axis=project_axis) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]])
a = axes[row,
col].imshow(field[slicing].sum(axis=project_axis) + 1,
cmap='magma',
extent=[0, field.shape[1], 0, field.shape[2]],
vmin=vmin,
vmax=vmax)
axes[row, col].set_xlabel('Mpc/h')
axes[row, col].set_ylabel('Mpc/h')
axes[row, col].set_title(f"{name} projection 0")
if colorbar:
fig.colorbar(a, ax=axes[row, col], shrink=0.7)
# Remove any empty subplots
for j in range(i + 1, nb_rows * nb_cols):