From 06edf57e243b724f79b4387089fdc3d50dea33ee Mon Sep 17 00:00:00 2001 From: Mayeul Aubin Date: Mon, 24 Mar 2025 17:33:03 +0100 Subject: [PATCH] improving slice to have the ticks --- analysis/slices.py | 98 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 25 deletions(-) diff --git a/analysis/slices.py b/analysis/slices.py index 0456e7b..581927d 100644 --- a/analysis/slices.py +++ b/analysis/slices.py @@ -1,13 +1,18 @@ import numpy as np +import sys +sys.path.append('/home/aubin/Simbelmyne/sbmy_control/') +from cosmo_params import register_arguments_cosmo, parse_arguments_cosmo + fs = 18 fs_titles = fs -4 def plot_imshow_with_reference( data_list, - reference, - titles, + reference=None, + titles=None, vmin=None, - vmax=None, + vmax=None, + L=None, cmap='viridis'): """ Plot the imshow of a list of 2D arrays with two rows: one for the data itself, @@ -23,38 +28,69 @@ def plot_imshow_with_reference( data_list, if titles is None: titles = [None for f in data_list] + + if L is None: + L = [len(data) for data in data_list] + elif isinstance(L, int) or isinstance(L, float): + L = [L for data in data_list] + + sep = 10 if L[0] < 100 else 20 if L[0] < 200 else 100 + ticks = [np.arange(0, l+1, sep)*len(dat)/l for l, dat in zip(L,data_list)] + tick_labels = [np.arange(0, l+1, sep) for l in L] def score(data, reference): return np.linalg.norm(data-reference)/np.linalg.norm(reference) n = len(data_list) - fig, axes = plt.subplots(2, n, figsize=(5 * n, 10)) + fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2)) if vmin is None or vmax is None: vmin = min(np.quantile(data,0.01) for data in data_list) vmax = max(np.quantile(data,0.99) for data in data_list) + + if reference is not None: vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list) vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list) else: vmin_diff = vmin vmax_diff = vmax - # Plot the data itself - for i, data in enumerate(data_list): - im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) - axes[0, i].set_title(titles[i], fontsize=fs_titles) - fig.colorbar(im, ax=axes[0, :], orientation='vertical') + if reference is not None: + # Plot the data itself + for i, data in enumerate(data_list): + im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes[0, i].set_title(titles[i], fontsize=fs_titles) + axes[0, i].set_xticks(ticks[i]) + axes[0, i].set_yticks(ticks[i]) + axes[0, i].set_xticklabels(tick_labels[i]) + axes[0, i].set_yticklabels(tick_labels[i]) + axes[0, i].set_xlabel('Mpc/h') + fig.colorbar(im, ax=axes[0, :], orientation='vertical') - # Plot the data compared to the reference - for i, data in enumerate(data_list): - im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff) - axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles) - fig.colorbar(im, ax=axes[1, :], orientation='vertical') - - # Add the score on the plots - for i, data in enumerate(data_list): - axes[1, i].text(0.5, 0.9, f"Score: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white') - # plt.tight_layout() + # Plot the data compared to the reference + for i, data in enumerate(data_list): + im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff) + axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles) + axes[1, i].set_xticks(ticks[i]) + axes[1, i].set_yticks(ticks[i]) + axes[1, i].set_xticklabels(tick_labels[i]) + axes[1, i].set_yticklabels(tick_labels[i]) + axes[1, i].set_xlabel('Mpc/h') + fig.colorbar(im, ax=axes[1, :], orientation='vertical') + + # Add the score on the plots + for i, data in enumerate(data_list): + axes[1, i].text(0.5, 0.9, f"Score: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white') + # plt.tight_layout() + else: + for i, data in enumerate(data_list): + im = axes[i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes[i].set_title(titles[i], fontsize=fs_titles) + axes[0, i].set_xticks(ticks[i]) + axes[0, i].set_yticks(ticks[i]) + axes[0, i].set_xticklabels(tick_labels[i]) + axes[0, i].set_yticklabels(tick_labels[i]) + fig.colorbar(im, ax=axes[:], orientation='vertical') return fig, axes @@ -76,36 +112,48 @@ if __name__ == "__main__": parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.') parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.') + register_arguments_cosmo(parser) + args = parser.parse_args() from pysbmy.field import read_field + from pysbmy.cosmology import d_plus - ref_field = read_field(args.directory+args.reference) + ref_field = read_field(args.directory+args.reference) if args.reference is not None else None fields = [read_field(args.directory+f) for f in args.filenames] if args.index is None: index = ref_field.N0//2 else: index=args.index + + # args.labels=[f"a={f.time:.2f}" for f in fields] + L = [f.L0 for f in fields] match args.axis: case 0 | 'x': - reference = ref_field.data[index,:,:] + reference = ref_field.data[index,:,:] if ref_field is not None else None fields = [f.data[index,:,:] for f in fields] + # reference = ref_field.data[index,:,:]/d_plus(1e-3,ref_field.time,parse_arguments_cosmo(args)) + # fields = [f.data[index,:,:]/d_plus(1e-3,f.time,parse_arguments_cosmo(args)) for f in fields] + # reference = ref_field.data[index,:,:]/d_plus(1e-3,0.05,parse_arguments_cosmo(args)) + # fields = [f.data[index,:,:]/d_plus(1e-3,time,parse_arguments_cosmo(args)) for f,time in zip(fields,[0.05, 1.0])] case 1 | 'y': - reference = ref_field.data[:,index,:] + reference = ref_field.data[:,index,:] if ref_field is not None else None fields = [f.data[:,index,:] for f in fields] case 2 | 'z': - reference = ref_field.data[:,:,index] + reference = ref_field.data[:,:,index] if ref_field is not None else None fields = [f.data[:,:,index] for f in fields] case _: raise ValueError(f"Wrong axis provided : {args.axis}") if args.log_scale: - reference = np.log10(2.+reference) + reference = np.log10(2.+reference) if ref_field is not None else None fields = [np.log10(2.+f) for f in fields] - fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap) + + + fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L) fig.suptitle(args.title) if args.output is not None: