simplified slice + slice diff + register colormaps

This commit is contained in:
Mayeul Aubin 2025-05-12 15:44:48 +02:00
parent 816e08b218
commit 9108ec488c
2 changed files with 138 additions and 65 deletions

36
analysis/colormaps.py Normal file
View file

@ -0,0 +1,36 @@
def register_colormaps(colormaps):
# Register cmasher
try:
import cmasher as cma
for name, cmap in cma.cm.cmap_d.items():
try:
colormaps.register(name=name, cmap=cmap)
except ValueError:
pass
except ImportError:
pass
# Register cmocean
try:
import cmocean as cmo
for name, cmap in cmo.cm.cmap_d.items():
try:
colormaps.register(name=name, cmap=cmap)
except ValueError:
pass
except ImportError:
pass
# Register cmcrameri
try:
import cmcrameri as cmc
for name, cmap in cmc.cm.cmaps.items():
try:
colormaps.register(name=name, cmap=cmap)
except ValueError:
pass
except ImportError:
pass

View file

@ -2,10 +2,21 @@ import numpy as np
import sys
sys.path.append('/home/aubin/Simbelmyne/sbmy_control/')
from cosmo_params import register_arguments_cosmo, parse_arguments_cosmo
fs = 18
fs_titles = fs -4
fs_titles = fs - 4
def add_ax_ticks(ax, ticks, tick_labels):
from matplotlib import ticker
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(tick_labels)
ax.set_yticklabels(tick_labels)
ax.set_xlabel('Mpc/h')
ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
def plot_imshow_with_reference( data_list,
reference=None,
@ -13,7 +24,9 @@ def plot_imshow_with_reference( data_list,
vmin=None,
vmax=None,
L=None,
cmap='viridis'):
cmap='viridis',
cmap_diff='PuOr',
ref_label="Reference"):
"""
Plot the imshow of a list of 2D arrays with two rows: one for the data itself,
one for the data compared to a reference. Each row will have a common colorbar.
@ -25,7 +38,9 @@ def plot_imshow_with_reference( data_list,
- cmap: colormap to be used for plotting
"""
import matplotlib.pyplot as plt
from matplotlib import ticker
from colormaps import register_colormaps
register_colormaps(plt.colormaps)
if titles is None:
titles = [None for f in data_list]
@ -43,7 +58,7 @@ def plot_imshow_with_reference( data_list,
return np.linalg.norm(data-reference)/np.linalg.norm(reference)
n = len(data_list)
fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2))
fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2), squeeze = False)
if vmin is None or vmax is None:
vmin = min(np.quantile(data,0.01) for data in data_list)
@ -52,72 +67,88 @@ def plot_imshow_with_reference( data_list,
if reference is not None:
vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list)
vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list)
vmin_diff = min(vmin_diff, -vmax_diff)
vmax_diff = -vmin_diff
else:
vmin_diff = vmin
vmax_diff = vmax
if reference is not None:
# Plot the data itself
for i, data in enumerate(data_list):
im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes[0, i].set_title(titles[i], fontsize=fs_titles)
axes[0, i].set_xticks(ticks[i])
axes[0, i].set_yticks(ticks[i])
axes[0, i].set_xticklabels(tick_labels[i])
axes[0, i].set_yticklabels(tick_labels[i])
axes[0, i].set_xlabel('Mpc/h')
axes[0, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
axes[0, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
# Plot the data itself
for i, data in enumerate(data_list):
im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes[0, i].set_title(titles[i], fontsize=fs_titles)
add_ax_ticks(axes[0, i], ticks[i], tick_labels[i])
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
if reference is not None:
# Plot the data compared to the reference
for i, data in enumerate(data_list):
im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles)
axes[1, i].set_xticks(ticks[i])
axes[1, i].set_yticks(ticks[i])
axes[1, i].set_xticklabels(tick_labels[i])
axes[1, i].set_yticklabels(tick_labels[i])
axes[1, i].set_xlabel('Mpc/h')
axes[1, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
axes[1, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
im = axes[1, i].imshow(data - reference, cmap=cmap_diff, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
axes[1, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles)
add_ax_ticks(axes[1, i], ticks[i], tick_labels[i])
fig.colorbar(im, ax=axes[1, :], orientation='vertical')
# Add the score on the plots
for i, data in enumerate(data_list):
axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white')
axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='black')
# plt.tight_layout()
else:
if len(data_list) == 1:
data_list = data_list[0]
im = axes.imshow(data_list, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes.set_title(titles[0], fontsize=fs_titles)
axes.set_xticks(ticks[0])
axes.set_yticks(ticks[0])
axes.set_xticklabels(tick_labels[0])
axes.set_yticklabels(tick_labels[0])
axes.set_xlabel('Mpc/h')
axes.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
axes.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
fig.colorbar(im, ax=axes, orientation='vertical')
else:
for i, data in enumerate(data_list):
im = axes[i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes[i].set_title(titles[i], fontsize=fs_titles)
axes[i].set_xticks(ticks[i])
axes[i].set_yticks(ticks[i])
axes[i].set_xticklabels(tick_labels[i])
axes[i].set_yticklabels(tick_labels[i])
axes[i].set_xlabel('Mpc/h')
axes[i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
axes[i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
fig.colorbar(im, ax=axes[:], orientation='vertical')
return fig, axes
def plot_imshow_diff(data_list,
reference,
titles,
vmin=None,
vmax=None,
L=None,
cmap='viridis',
ref_label="Reference"):
import matplotlib.pyplot as plt
from colormaps import register_colormaps
register_colormaps(plt.colormaps)
if reference is None:
raise ValueError("Reference field is None")
if titles is None:
titles = [None for f in data_list]
if L is None:
L = [len(data) for data in data_list]
elif isinstance(L, int) or isinstance(L, float):
L = [L for data in data_list]
sep = 10 if L[0] < 50 else 20 if L[0] < 100 else 50 if L[0]<250 else 100 if L[0] < 500 else 200 if L[0] < 1000 else 500 if L[0] < 2500 else 1000
ticks = [np.arange(0, l+1, sep)*len(dat)/l for l, dat in zip(L,data_list)]
tick_labels = [np.arange(0, l+1, sep) for l in L]
def score(data, reference):
return np.linalg.norm(data-reference)/np.linalg.norm(reference)
n = len(data_list)
fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), dpi=max(500, data_list[0].shape[0]//2), squeeze = False)
if vmin is None or vmax is None:
vmin = min(np.quantile(data-reference,0.01) for data in data_list)
vmax = max(np.quantile(data-reference,0.99) for data in data_list)
vmin = min(vmin, -vmax)
vmax = -vmin
# Plot the data compared to the reference
for i, data in enumerate(data_list):
im = axes[0, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes[0, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles)
add_ax_ticks(axes[0, i], ticks[i], tick_labels[i])
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
return fig, axes
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser(description='Comparisons of fields slices.')
@ -134,16 +165,24 @@ if __name__ == "__main__":
parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.')
parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.')
parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.')
# register_arguments_cosmo(parser)
parser.add_argument('--diff', action='store_true', help='Plot only the difference with the reference field.')
parser.add_argument('--ref_label', type=str, default='Reference', help='Label for the reference field.')
parser.add_argument('--cmap_diff', type=str, default='PuOr', help='Colormap to be used for the difference plot.')
args = parser.parse_args()
from pysbmy.field import read_field
# from pysbmy.cosmology import d_plus
ref_label = args.ref_label
ref_field = read_field(args.directory+args.reference) if args.reference is not None else None
fields = [read_field(args.directory+f) for f in args.filenames]
fields = []
for k,f in enumerate(args.filenames):
if args.reference is not None and f == args.reference:
fields.append(ref_field) # Simply copy the reference field instead of reading it again
if args.labels is not None:
ref_label = args.labels[k] # Use the label of the field as the reference label
else:
fields.append(read_field(args.directory+f))
if args.index is None:
index = fields[0].N0//2
@ -157,10 +196,6 @@ if __name__ == "__main__":
case 0 | 'x':
reference = ref_field.data[index,:,:] if ref_field is not None else None
fields = [f.data[index,:,:] for f in fields]
# reference = ref_field.data[index,:,:]/d_plus(1e-3,ref_field.time,parse_arguments_cosmo(args))
# fields = [f.data[index,:,:]/d_plus(1e-3,f.time,parse_arguments_cosmo(args)) for f in fields]
# reference = ref_field.data[index,:,:]/d_plus(1e-3,0.05,parse_arguments_cosmo(args))
# fields = [f.data[index,:,:]/d_plus(1e-3,time,parse_arguments_cosmo(args)) for f,time in zip(fields,[0.05, 1.0])]
case 1 | 'y':
reference = ref_field.data[:,index,:] if ref_field is not None else None
fields = [f.data[:,index,:] for f in fields]
@ -175,8 +210,10 @@ if __name__ == "__main__":
fields = [np.log10(2.+f) for f in fields]
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L)
if args.diff:
fig, axes = plot_imshow_diff(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap_diff, L=L, ref_label=ref_label)
else:
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L, ref_label=ref_label, cmap_diff=args.cmap_diff)
fig.suptitle(args.title)
if args.output is not None: