import numpy as np

fs = 18
fs_titles = fs - 4


def _relative_error(data, reference):
    """Return ||data - reference|| / ||reference|| using numpy's default norm."""
    return np.linalg.norm(data - reference) / np.linalg.norm(reference)


def _color_limits(data_list, reference, vmin=None, vmax=None):
    """Resolve the colour limits of the data row and of the difference row.

    Returns the tuple (vmin, vmax, vmin_diff, vmax_diff).
    """
    if vmin is not None and vmax is not None:
        # Both bounds supplied: they are applied to both rows.
        return vmin, vmax, vmin, vmax

    # At least one bound is missing: the difference row is scaled
    # automatically from the 1% / 99% quantiles of the differences.
    vmin_diff = min(np.quantile(data - reference, 0.01) for data in data_list)
    vmax_diff = max(np.quantile(data - reference, 0.99) for data in data_list)

    # On the data row each bound is resolved on its own, so a single
    # user-supplied bound is honoured instead of being overwritten.
    if vmin is None:
        vmin = min(np.quantile(data, 0.01) for data in data_list)
    if vmax is None:
        vmax = max(np.quantile(data, 0.99) for data in data_list)
    return vmin, vmax, vmin_diff, vmax_diff


def _slice_along_axis(data, axis, index):
    """Return the 2D slice of a 3D array at position `index` along `axis`."""
    if axis not in (0, 1, 2):
        raise ValueError(f"Wrong axis provided : {axis}")
    slicer = [slice(None)] * 3
    slicer[axis] = index
    return data[tuple(slicer)]


def plot_imshow_with_reference(
    data_list, reference, titles, vmin=None, vmax=None, cmap='viridis'):
    """
    Plot the imshow of a list of 2D arrays with two rows: one for the data
    itself, one for the data compared to a reference. Each row will have a
    common colorbar.

    Parameters:
    - data_list: non-empty list of 2D arrays to be plotted
    - reference: 2D array to be used as reference for comparison
    - titles: list of titles for each subplot, or None for no titles
    - vmin, vmax: colour limits. If both are given they are used for both
      rows. A bound left to None is computed for the data row from the
      1% / 99% quantiles of the data; if either bound is None the second
      row is scaled from the 1% / 99% quantiles of the differences.
    - cmap: colormap to be used for plotting

    Returns:
    - fig: the matplotlib figure
    - axes: array of axes of shape (2, len(data_list))

    Raises:
    - ValueError: if data_list is empty
    """
    import matplotlib.pyplot as plt

    n = len(data_list)
    if n == 0:
        raise ValueError("data_list must contain at least one array")

    if titles is None:
        titles = [None] * n

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so the axes[row, col] indexing below also works when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(5 * n, 10), squeeze=False)

    vmin, vmax, vmin_diff, vmax_diff = _color_limits(
        data_list, reference, vmin, vmax)

    # Plot the data itself
    for i, data in enumerate(data_list):
        im = axes[0, i].imshow(data, cmap=cmap, origin='lower',
                               vmin=vmin, vmax=vmax)
        axes[0, i].set_title(titles[i], fontsize=fs_titles)
    fig.colorbar(im, ax=axes[0, :], orientation='vertical')

    # Plot the data compared to the reference
    for i, data in enumerate(data_list):
        im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower',
                               vmin=vmin_diff, vmax=vmax_diff)
        if titles[i] is None:
            # Avoid rendering the literal text "None - Reference".
            diff_title = 'Difference to reference'
        else:
            diff_title = f'{titles[i]} - Reference'
        axes[1, i].set_title(diff_title, fontsize=fs_titles)
    fig.colorbar(im, ax=axes[1, :], orientation='vertical')

    # Add the score on the plots
    for i, data in enumerate(data_list):
        axes[1, i].text(0.5, 0.9,
                        f"Score: {_relative_error(data, reference):.2e}",
                        fontsize=10, transform=axes[1, i].transAxes,
                        color='white')

    # plt.tight_layout()

    return fig, axes


def main(argv=None):
    """Command-line entry point: plot slices of fields against a reference.

    Parameters:
    - argv: list of command-line arguments; None lets argparse read sys.argv
    """
    import os
    from argparse import ArgumentParser

    parser = ArgumentParser(description='Comparisons of fields slices.')
    parser.add_argument('-a', '--axis', type=int, default=0,
                        help='Axis along which the slices will be taken.')
    parser.add_argument('-i', '--index', type=int, default=None,
                        help='Index of the slice along the axis.')
    parser.add_argument('-d', '--directory', type=str, required=True,
                        help='Directory containing the fields files.')
    # The reference is always read below, so it has to be provided.
    parser.add_argument('-ref', '--reference', type=str, required=True,
                        help='Reference field file.')
    parser.add_argument('-f', '--filenames', type=str, nargs='+',
                        required=True, help='Field files to be plotted.')
    parser.add_argument('-o', '--output', type=str, default=None,
                        help='Output plot file name.')
    parser.add_argument('-l', '--labels', type=str, nargs='+', default=None,
                        help='Labels for each field.')
    parser.add_argument('-c', '--cmap', type=str, default='viridis',
                        help='Colormap to be used for plotting.')
    parser.add_argument('-vmin', type=float, default=None,
                        help='Minimum value for the colorbar.')
    parser.add_argument('-vmax', type=float, default=None,
                        help='Maximum value for the colorbar.')
    parser.add_argument('-t', '--title', type=str, default=None,
                        help='Title for the plot.')
    parser.add_argument('-log', '--log_scale', action='store_true',
                        help='Use log scale for the data.')

    args = parser.parse_args(argv)

    if args.axis not in (0, 1, 2):
        raise ValueError(f"Wrong axis provided : {args.axis}")
    if args.labels is not None and len(args.labels) < len(args.filenames):
        parser.error('fewer labels than field files were provided')

    from pysbmy.field import read_field

    # os.path.join works whether or not the directory ends with a separator.
    ref_field = read_field(os.path.join(args.directory, args.reference))
    fields = [read_field(os.path.join(args.directory, f))
              for f in args.filenames]

    if args.index is None:
        # Default to the middle of the selected axis of the reference.
        # NOTE(review): assumes ref_field.data has shape (N0, N1, N2) - confirm.
        index = np.shape(ref_field.data)[args.axis] // 2
    else:
        index = args.index

    reference = _slice_along_axis(ref_field.data, args.axis, index)
    fields = [_slice_along_axis(f.data, args.axis, index) for f in fields]

    if args.log_scale:
        reference = np.log10(2. + reference)
        fields = [np.log10(2. + f) for f in fields]

    fig, axes = plot_imshow_with_reference(
        fields, reference, args.labels,
        vmin=args.vmin, vmax=args.vmax, cmap=args.cmap)
    if args.title is not None:
        fig.suptitle(args.title)

    if args.output is not None:
        fig.savefig(args.output)
    else:
        fig.savefig(os.path.join(args.directory, 'slices.png'))


if __name__ == "__main__":
    main()