From 9108ec488c354da43267c236834577e01907d29d Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Mon, 12 May 2025 15:44:48 +0200 Subject: [PATCH] simplified slice + slice diff + register colormaps --- analysis/colormaps.py | 36 +++++++++ analysis/slices.py | 167 ++++++++++++++++++++++++++---------------- 2 files changed, 138 insertions(+), 65 deletions(-) create mode 100644 analysis/colormaps.py diff --git a/analysis/colormaps.py b/analysis/colormaps.py new file mode 100644 index 0000000..54ddf3a --- /dev/null +++ b/analysis/colormaps.py @@ -0,0 +1,36 @@ + + +def register_colormaps(colormaps): + + # Register cmasher + try: + import cmasher as cma + for name, cmap in cma.cm.cmap_d.items(): + try: + colormaps.register(name=name, cmap=cmap) + except ValueError: + pass + except ImportError: + pass + + # Register cmocean + try: + import cmocean as cmo + for name, cmap in cmo.cm.cmap_d.items(): + try: + colormaps.register(name=name, cmap=cmap) + except ValueError: + pass + except ImportError: + pass + + # Register cmcrameri + try: + import cmcrameri as cmc + for name, cmap in cmc.cm.cmaps.items(): + try: + colormaps.register(name=name, cmap=cmap) + except ValueError: + pass + except ImportError: + pass diff --git a/analysis/slices.py b/analysis/slices.py index 4f01418..2febe16 100644 --- a/analysis/slices.py +++ b/analysis/slices.py @@ -2,10 +2,21 @@ import numpy as np import sys sys.path.append('/home/aubin/Simbelmyne/sbmy_control/') -from cosmo_params import register_arguments_cosmo, parse_arguments_cosmo fs = 18 -fs_titles = fs -4 +fs_titles = fs - 4 + +def add_ax_ticks(ax, ticks, tick_labels): + from matplotlib import ticker + ax.set_xticks(ticks) + ax.set_yticks(ticks) + ax.set_xticklabels(tick_labels) + ax.set_yticklabels(tick_labels) + ax.set_xlabel('Mpc/h') + ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + + def plot_imshow_with_reference( data_list, reference=None, @@ -13,7 +24,9 @@ def plot_imshow_with_reference( data_list, vmin=None, vmax=None, L=None, - cmap='viridis'): + cmap='viridis', + cmap_diff='PuOr', + ref_label="Reference"): """ Plot the imshow of a list of 2D arrays with two rows: one for the data itself, one for the data compared to a reference. Each row will have a common colorbar. @@ -25,7 +38,9 @@ def plot_imshow_with_reference( data_list, - cmap: colormap to be used for plotting """ import matplotlib.pyplot as plt - from matplotlib import ticker + + from colormaps import register_colormaps + register_colormaps(plt.colormaps) if titles is None: titles = [None for f in data_list] @@ -43,7 +58,7 @@ def plot_imshow_with_reference( data_list, return np.linalg.norm(data-reference)/np.linalg.norm(reference) n = len(data_list) - fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2)) + fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2), squeeze = False) if vmin is None or vmax is None: vmin = min(np.quantile(data,0.01) for data in data_list) @@ -52,72 +67,88 @@ def plot_imshow_with_reference( data_list, if reference is not None: vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list) vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list) + vmin_diff = min(vmin_diff, -vmax_diff) + vmax_diff = -vmin_diff else: vmin_diff = vmin vmax_diff = vmax - if reference is not None: - # Plot the data itself - for i, data in enumerate(data_list): - im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) - axes[0, i].set_title(titles[i], fontsize=fs_titles) - axes[0, i].set_xticks(ticks[i]) - axes[0, i].set_yticks(ticks[i]) - axes[0, i].set_xticklabels(tick_labels[i]) - axes[0, i].set_yticklabels(tick_labels[i]) - axes[0, i].set_xlabel('Mpc/h') - axes[0, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - axes[0, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - fig.colorbar(im, ax=axes[0, :], orientation='vertical') + # Plot the data itself + for i, data in enumerate(data_list): + im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes[0, i].set_title(titles[i], fontsize=fs_titles) + add_ax_ticks(axes[0, i], ticks[i], tick_labels[i]) + fig.colorbar(im, ax=axes[0, :], orientation='vertical') + if reference is not None: # Plot the data compared to the reference for i, data in enumerate(data_list): - im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff) - axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles) - axes[1, i].set_xticks(ticks[i]) - axes[1, i].set_yticks(ticks[i]) - axes[1, i].set_xticklabels(tick_labels[i]) - axes[1, i].set_yticklabels(tick_labels[i]) - axes[1, i].set_xlabel('Mpc/h') - axes[1, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - axes[1, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + im = axes[1, i].imshow(data - reference, cmap=cmap_diff, origin='lower', vmin=vmin_diff, vmax=vmax_diff) + axes[1, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles) + add_ax_ticks(axes[1, i], ticks[i], tick_labels[i]) fig.colorbar(im, ax=axes[1, :], orientation='vertical') # Add the score on the plots for i, data in enumerate(data_list): - axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white') + axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='black') # plt.tight_layout() - else: - - if len(data_list) == 1: - data_list = data_list[0] - im = axes.imshow(data_list, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) - axes.set_title(titles[0], fontsize=fs_titles) - axes.set_xticks(ticks[0]) - axes.set_yticks(ticks[0]) - axes.set_xticklabels(tick_labels[0]) - axes.set_yticklabels(tick_labels[0]) - axes.set_xlabel('Mpc/h') - axes.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - axes.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - fig.colorbar(im, ax=axes, orientation='vertical') - - else: - for i, data in enumerate(data_list): - im = axes[i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) - axes[i].set_title(titles[i], fontsize=fs_titles) - axes[i].set_xticks(ticks[i]) - axes[i].set_yticks(ticks[i]) - axes[i].set_xticklabels(tick_labels[i]) - axes[i].set_yticklabels(tick_labels[i]) - axes[i].set_xlabel('Mpc/h') - axes[i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - axes[i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - fig.colorbar(im, ax=axes[:], orientation='vertical') return fig, axes + +def plot_imshow_diff(data_list, + reference, + titles, + vmin=None, + vmax=None, + L=None, + cmap='viridis', + ref_label="Reference"): + + import matplotlib.pyplot as plt + + from colormaps import register_colormaps + register_colormaps(plt.colormaps) + + if reference is None: + raise ValueError("Reference field is None") + + if titles is None: + titles = [None for f in data_list] + + if L is None: + L = [len(data) for data in data_list] + elif isinstance(L, int) or isinstance(L, float): + L = [L for data in data_list] + + sep = 10 if L[0] < 50 else 20 if L[0] < 100 else 50 if L[0]<250 else 100 if L[0] < 500 else 200 if L[0] < 1000 else 500 if L[0] < 2500 else 1000 + ticks = [np.arange(0, l+1, sep)*len(dat)/l for l, dat in zip(L,data_list)] + tick_labels = [np.arange(0, l+1, sep) for l in L] + + def score(data, reference): + return np.linalg.norm(data-reference)/np.linalg.norm(reference) + + n = len(data_list) + fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), dpi=max(500, data_list[0].shape[0]//2), squeeze = False) + + if vmin is None or vmax is None: + vmin = min(np.quantile(data-reference,0.01) for data in data_list) + vmax = max(np.quantile(data-reference,0.99) for data in data_list) + vmin = min(vmin, -vmax) + vmax = -vmin + + # Plot the data compared to the reference + for i, data in enumerate(data_list): + im = axes[0, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes[0, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles) + add_ax_ticks(axes[0, i], ticks[i], tick_labels[i]) + fig.colorbar(im, ax=axes[0, :], orientation='vertical') + + return fig, axes + + + if __name__ == "__main__": from argparse import ArgumentParser parser = ArgumentParser(description='Comparisons of fields slices.') @@ -134,16 +165,24 @@ if __name__ == "__main__": parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.') parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.') parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.') - - # register_arguments_cosmo(parser) + parser.add_argument('--diff', action='store_true', help='Plot only the difference with the reference field.') + parser.add_argument('--ref_label', type=str, default='Reference', help='Label for the reference field.') + parser.add_argument('--cmap_diff', type=str, default='PuOr', help='Colormap to be used for the difference plot.') args = parser.parse_args() from pysbmy.field import read_field - # from pysbmy.cosmology import d_plus + ref_label = args.ref_label ref_field = read_field(args.directory+args.reference) if args.reference is not None else None - fields = [read_field(args.directory+f) for f in args.filenames] + fields = [] + for k,f in enumerate(args.filenames): + if args.reference is not None and f == args.reference: + fields.append(ref_field) # Simply copy the reference field instead of reading it again + if args.labels is not None: + ref_label = args.labels[k] # Use the label of the field as the reference label + else: + fields.append(read_field(args.directory+f)) if args.index is None: index = fields[0].N0//2 @@ -157,10 +196,6 @@ if __name__ == "__main__": case 0 | 'x': reference = ref_field.data[index,:,:] if ref_field is not None else None fields = [f.data[index,:,:] for f in fields] - # reference = ref_field.data[index,:,:]/d_plus(1e-3,ref_field.time,parse_arguments_cosmo(args)) - # fields = [f.data[index,:,:]/d_plus(1e-3,f.time,parse_arguments_cosmo(args)) for f in fields] - # reference = ref_field.data[index,:,:]/d_plus(1e-3,0.05,parse_arguments_cosmo(args)) - # fields = [f.data[index,:,:]/d_plus(1e-3,time,parse_arguments_cosmo(args)) for f,time in zip(fields,[0.05, 1.0])] case 1 | 'y': reference = ref_field.data[:,index,:] if ref_field is not None else None fields = [f.data[:,index,:] for f in fields] @@ -175,8 +210,10 @@ if __name__ == "__main__": fields = [np.log10(2.+f) for f in fields] - - fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L) + if args.diff: + fig, axes = plot_imshow_diff(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap_diff, L=L, ref_label=ref_label) + else: + fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L, ref_label=ref_label, cmap_diff=args.cmap_diff) fig.suptitle(args.title) if args.output is not None: