229 lines
9.6 KiB
Python
Executable file
229 lines
9.6 KiB
Python
Executable file
import numpy as np
|
|
|
|
import sys
|
|
import os
|
|
# Directory added to the import path so that the lazily imported helper modules
# (presumably ``colormaps``, imported inside the plotting functions) can be
# found. It can be overridden with the SBMY_CONTROL_PATH environment variable;
# the default is the original hard-coded location.
SBMY_CONTROL_PATH = os.environ.get('SBMY_CONTROL_PATH', '/home/aubin/Simbelmyne/sbmy_control/')
if SBMY_CONTROL_PATH not in sys.path:
    # Avoid accumulating duplicate entries if the module is re-imported.
    sys.path.append(SBMY_CONTROL_PATH)

# Base font size; the subplot-title size (fs_titles) is derived from it.
fs = 18
fs_titles = fs - 4
|
|
|
|
def add_ax_ticks(ax, ticks, tick_labels):
    """
    Put the same ticks on the x and y axes of an image plot and label the x axis.

    Parameters:
    - ax: matplotlib Axes to decorate
    - ticks: tick positions in the axes' data coordinates (for the imshow plots
      of this module, pixel-index coordinates), used for both axes
    - tick_labels: values displayed at those positions; each one is converted
      with int() so that labels read e.g. ``100`` rather than ``100.0``
    """
    # Convert once and reuse for both axes.
    labels = [int(t) for t in tick_labels]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)
    ax.set_xlabel('Mpc/h')
|
|
|
|
|
|
|
|
def plot_imshow_with_reference(data_list,
                               reference=None,
                               titles=None,
                               vmin=None,
                               vmax=None,
                               L=None,
                               cmap='viridis',
                               cmap_diff='PuOr',
                               ref_label="Reference"):
    """
    Plot a list of 2D arrays with imshow, optionally with their difference to a
    reference.

    The first row shows the arrays themselves, sharing one colour scale and one
    colorbar. If ``reference`` is given, a second row shows ``data - reference``
    with a common colour scale symmetric around zero, its own colorbar, and the
    relative error ``||data - reference|| / ||reference||`` (labelled "RMS",
    i.e. the RMS of the difference relative to the RMS of the reference)
    written on each panel.

    Parameters:
    - data_list: list of 2D arrays to be plotted
    - reference: 2D array used as reference for the comparison (optional)
    - titles: list of titles, one per array (default: no titles)
    - vmin, vmax: colour limits of the first row. A limit given by the caller
      is always used; a limit left to None is taken from the 1% (vmin) or 99%
      (vmax) quantile over all arrays.
    - L: physical side length of the arrays (the unit shown on the axes), either
      a scalar shared by all arrays or one value per array. Defaults to the
      number of rows of each array, i.e. ticks in pixel units.
    - cmap: colormap of the first row
    - cmap_diff: colormap of the difference row
    - ref_label: name of the reference, used in the difference-row titles

    Returns:
    - fig, axes: the figure and a 2D array of axes with shape
      (1 or 2, len(data_list))
    """
    import matplotlib.pyplot as plt

    from colormaps import register_colormaps
    register_colormaps(plt.colormaps)

    n = len(data_list)

    if titles is None:
        titles = [None] * n

    # One physical length per array. np.ndim(...) == 0 also accepts numpy
    # scalar types such as np.int64, which isinstance(L, (int, float)) missed.
    if L is None:
        L = [len(data) for data in data_list]
    elif np.ndim(L) == 0:
        L = [L] * n

    # Tick spacing (in the units of L), chosen from the first array's length.
    thresholds = ((50, 10), (100, 20), (250, 50), (500, 100), (1000, 200), (2500, 500))
    sep = next((step for limit, step in thresholds if L[0] < limit), 1000)
    # Labels are physical coordinates; positions are converted to pixel
    # indices, which is the coordinate system imshow draws in.
    tick_labels = [np.arange(0, length + 1, sep) for length in L]
    ticks = [labels * len(data) / length
             for labels, length, data in zip(tick_labels, L, data_list)]

    # Colour limits of the data row: each missing limit is resolved on its own,
    # so a limit supplied by the caller is never discarded.
    if vmin is None:
        vmin = min(np.quantile(data, 0.01) for data in data_list)
    if vmax is None:
        vmax = max(np.quantile(data, 0.99) for data in data_list)

    if reference is not None:
        # Compute each difference once and reuse it for limits, plot and score.
        diffs = [data - reference for data in data_list]
        lo = min(np.quantile(diff, 0.01) for diff in diffs)
        hi = max(np.quantile(diff, 0.99) for diff in diffs)
        # Symmetric range around zero for the diverging difference colormap.
        vmax_diff = max(-lo, hi)
        vmin_diff = -vmax_diff

    nrows = 1 if reference is None else 2
    fig, axes = plt.subplots(nrows, n, figsize=(5 * n, 5 * nrows),
                             dpi=max(500, data_list[0].shape[0] // 2), squeeze=False)

    # First row: the data itself.
    for i, data in enumerate(data_list):
        ax = axes[0, i]
        im = ax.imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
        ax.set_title(titles[i], fontsize=fs_titles)
        add_ax_ticks(ax, ticks[i], tick_labels[i])
    fig.colorbar(im, ax=axes[0, :], orientation='vertical')

    # Second row: the data compared to the reference.
    if reference is not None:
        ref_norm = np.linalg.norm(reference)
        for i, diff in enumerate(diffs):
            ax = axes[1, i]
            im = ax.imshow(diff, cmap=cmap_diff, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
            ax.set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles)
            add_ax_ticks(ax, ticks[i], tick_labels[i])
            score = np.linalg.norm(diff) / ref_norm
            ax.text(0.5, 0.9, f"RMS: {score:.2e}", fontsize=10, transform=ax.transAxes, color='black')
        fig.colorbar(im, ax=axes[1, :], orientation='vertical')

    return fig, axes
|
|
|
|
|
|
|
|
def plot_imshow_diff(data_list,
                     reference,
                     titles,
                     vmin=None,
                     vmax=None,
                     L=None,
                     cmap='viridis',
                     ref_label="Reference"):
    """
    Plot ``data - reference`` for each 2D array of ``data_list`` in a single row
    sharing one colour scale and one colorbar.

    Parameters:
    - data_list: list of 2D arrays to be compared
    - reference: 2D array subtracted from every element of data_list (required)
    - titles: list of titles, one per array, or None for no titles
    - vmin, vmax: colour limits. If both are None, the range is symmetric
      around zero and covers the 1% / 99% quantiles of all differences. If only
      one is given, the other is its mirror (-vmax or -vmin) so that the scale
      stays centred on zero. If both are given, they are used as-is.
    - L: physical side length of the arrays (the unit shown on the axes), either
      a scalar shared by all arrays or one value per array. Defaults to the
      number of rows of each array, i.e. ticks in pixel units.
    - cmap: colormap used for the differences
    - ref_label: name of the reference, used in the subplot titles

    Returns:
    - fig, axes: the figure and a (1, len(data_list)) array of axes

    Raises:
    - ValueError: if reference is None
    """
    # Validate before paying for the matplotlib / colormap imports.
    if reference is None:
        raise ValueError("Reference field is None")

    import matplotlib.pyplot as plt

    from colormaps import register_colormaps
    register_colormaps(plt.colormaps)

    n = len(data_list)

    if titles is None:
        titles = [None] * n

    # One physical length per array. np.ndim(...) == 0 also accepts numpy
    # scalar types such as np.int64, which isinstance(L, (int, float)) missed.
    if L is None:
        L = [len(data) for data in data_list]
    elif np.ndim(L) == 0:
        L = [L] * n

    # Tick spacing (in the units of L), chosen from the first array's length.
    thresholds = ((50, 10), (100, 20), (250, 50), (500, 100), (1000, 200), (2500, 500))
    sep = next((step for limit, step in thresholds if L[0] < limit), 1000)
    # Labels are physical coordinates; positions are converted to pixel
    # indices, which is the coordinate system imshow draws in.
    tick_labels = [np.arange(0, length + 1, sep) for length in L]
    ticks = [labels * len(data) / length
             for labels, length, data in zip(tick_labels, L, data_list)]

    diffs = [data - reference for data in data_list]

    # Resolve the colour limits without discarding a caller-supplied one.
    if vmin is None and vmax is None:
        lo = min(np.quantile(diff, 0.01) for diff in diffs)
        hi = max(np.quantile(diff, 0.99) for diff in diffs)
        vmax = max(-lo, hi)
        vmin = -vmax
    elif vmin is None:
        vmin = -vmax
    elif vmax is None:
        vmax = -vmin

    fig, axes = plt.subplots(1, n, figsize=(5 * n, 5),
                             dpi=max(500, data_list[0].shape[0] // 2), squeeze=False)

    for i, diff in enumerate(diffs):
        ax = axes[0, i]
        im = ax.imshow(diff, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
        ax.set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles)
        add_ax_ticks(ax, ticks[i], tick_labels[i])
    fig.colorbar(im, ax=axes[0, :], orientation='vertical')

    return fig, axes
|
|
|
|
|
|
def console_main():
    """
    Command-line entry point: plot 2D slices of 3D fields, optionally compared
    with a reference field, and save the figure.

    Every file in ``--filenames`` (and ``--reference``, if given) is read with
    ``pysbmy.field.read_field`` from ``--directory``. A slice is taken along
    ``--axis`` at ``--index`` (default: the middle of that axis), optionally
    transformed with ``log10(2 + x)``, then drawn with
    ``plot_imshow_with_reference`` or, with ``--diff``, ``plot_imshow_diff``.
    The figure is saved to ``--output``, or to ``slices.jpg`` inside
    ``--directory`` by default.

    Raises:
    - FileNotFoundError: if the reference or one of the field files is missing.
    """
    from argparse import ArgumentParser, ArgumentTypeError

    axis_names = {'x': 0, 'y': 1, 'z': 2}

    def parse_axis(value):
        # Accept the axis by number (0, 1, 2) or by name (x, y, z); with a plain
        # type=int the names could never reach the code.
        key = value.strip().lower()
        if key in axis_names:
            return axis_names[key]
        try:
            axis = int(key)
        except ValueError:
            axis = None
        if axis not in (0, 1, 2):
            raise ArgumentTypeError(f"invalid axis {value!r}: expected 0, 1, 2, x, y or z")
        return axis

    parser = ArgumentParser(description='Comparisons of fields slices.')

    parser.add_argument('-a','--axis', type=parse_axis, default=0, help='Axis along which the slices will be taken (0, 1, 2 or x, y, z).')
    parser.add_argument('-i','--index', type=int, default=None, help='Index of the slice along the axis.')
    parser.add_argument('-d', '--directory', type=str, required=True, help='Directory containing the fields files.')
    parser.add_argument('-ref', '--reference', type=str, default=None, help='Reference field file.')
    parser.add_argument('-f', '--filenames', type=str, nargs='+', required=True, help='Field files to be plotted.')
    parser.add_argument('-o', '--output', type=str, default=None, help='Output plot file name.')
    parser.add_argument('-l', '--labels', type=str, nargs='+', default=None, help='Labels for each field.')
    parser.add_argument('-c', '--cmap', type=str, default='viridis', help='Colormap to be used for plotting.')
    parser.add_argument('-vmin', type=float, default=None, help='Minimum value for the colorbar.')
    parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.')
    parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.')
    parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.')
    parser.add_argument('--diff', action='store_true', help='Plot only the difference with the reference field.')
    parser.add_argument('--ref_label', type=str, default='Reference', help='Label for the reference field.')
    parser.add_argument('--cmap_diff', type=str, default='PuOr', help='Colormap to be used for the difference plot.')

    args = parser.parse_args()

    # Fail early with a usage message instead of an IndexError / ValueError
    # deep inside the plotting code.
    if args.labels is not None and len(args.labels) != len(args.filenames):
        parser.error("--labels must give exactly one label per file in --filenames.")
    if args.diff and args.reference is None:
        parser.error("--diff requires a --reference field.")

    from pysbmy.field import read_field

    def checked_path(name):
        # os.path.join works whether or not --directory ends with a separator;
        # plain string concatenation silently required a trailing '/'.
        path = os.path.join(args.directory, name)
        if not os.path.exists(path):
            raise FileNotFoundError(f"File {path} does not exist.")
        return path

    ref_label = args.ref_label
    ref_field = read_field(checked_path(args.reference)) if args.reference is not None else None

    fields = []
    for k, name in enumerate(args.filenames):
        path = checked_path(name)
        if args.reference is not None and name == args.reference:
            # Reuse the already-loaded reference instead of reading it again.
            fields.append(ref_field)
            if args.labels is not None:
                # The label of this field becomes the reference label.
                ref_label = args.labels[k]
        else:
            fields.append(read_field(path))

    axis = args.axis
    # Default to the middle of the axis being sliced (not always axis 0).
    index = fields[0].data.shape[axis] // 2 if args.index is None else args.index

    # NOTE(review): L0 is used as the length of both in-plane axes, which
    # assumes a cubic box — confirm for non-cubic fields.
    L = [f.L0 for f in fields]

    # Index tuple selecting the 2D slice, e.g. (index, :, :) for axis 0.
    selector = tuple(index if dim == axis else slice(None) for dim in range(3))
    reference = ref_field.data[selector] if ref_field is not None else None
    slices = [f.data[selector] for f in fields]

    if args.log_scale:
        reference = np.log10(2. + reference) if reference is not None else None
        slices = [np.log10(2. + s) for s in slices]

    if args.diff:
        fig, axes = plot_imshow_diff(slices, reference, args.labels, vmin=args.vmin, vmax=args.vmax,
                                     cmap=args.cmap_diff, L=L, ref_label=ref_label)
    else:
        fig, axes = plot_imshow_with_reference(slices, reference, args.labels, vmin=args.vmin, vmax=args.vmax,
                                               cmap=args.cmap, L=L, ref_label=ref_label,
                                               cmap_diff=args.cmap_diff)
    # The title applies to both plotting modes.
    if args.title is not None:
        fig.suptitle(args.title)

    output = args.output if args.output is not None else os.path.join(args.directory, 'slices.jpg')
    fig.savefig(output, bbox_inches='tight')


if __name__ == "__main__":
    console_main()
|