added analysis
This commit is contained in:
parent
c9bddb08d7
commit
72ce9a3b99
3 changed files with 449 additions and 0 deletions
114
analysis/slices.py
Normal file
114
analysis/slices.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
import numpy as np
|
||||
|
||||
fs = 18
|
||||
fs_titles = fs -4
|
||||
|
||||
def plot_imshow_with_reference( data_list,
|
||||
reference,
|
||||
titles,
|
||||
vmin=None,
|
||||
vmax=None,
|
||||
cmap='viridis'):
|
||||
"""
|
||||
Plot the imshow of a list of 2D arrays with two rows: one for the data itself,
|
||||
one for the data compared to a reference. Each row will have a common colorbar.
|
||||
|
||||
Parameters:
|
||||
- data_list: list of 2D arrays to be plotted
|
||||
- reference: 2D array to be used as reference for comparison
|
||||
- titles: list of titles for each subplot
|
||||
- cmap: colormap to be used for plotting
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
if titles is None:
|
||||
titles = [None for f in data_list]
|
||||
|
||||
def score(data, reference):
|
||||
return np.linalg.norm(data-reference)/np.linalg.norm(reference)
|
||||
|
||||
n = len(data_list)
|
||||
fig, axes = plt.subplots(2, n, figsize=(5 * n, 10))
|
||||
|
||||
if vmin is None or vmax is None:
|
||||
vmin = min(np.quantile(data,0.01) for data in data_list)
|
||||
vmax = max(np.quantile(data,0.99) for data in data_list)
|
||||
vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list)
|
||||
vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list)
|
||||
else:
|
||||
vmin_diff = vmin
|
||||
vmax_diff = vmax
|
||||
|
||||
# Plot the data itself
|
||||
for i, data in enumerate(data_list):
|
||||
im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
|
||||
axes[0, i].set_title(titles[i], fontsize=fs_titles)
|
||||
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
|
||||
|
||||
# Plot the data compared to the reference
|
||||
for i, data in enumerate(data_list):
|
||||
im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
|
||||
axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles)
|
||||
fig.colorbar(im, ax=axes[1, :], orientation='vertical')
|
||||
|
||||
# Add the score on the plots
|
||||
for i, data in enumerate(data_list):
|
||||
axes[1, i].text(0.5, 0.9, f"Score: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white')
|
||||
# plt.tight_layout()
|
||||
|
||||
return fig, axes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
parser = ArgumentParser(description='Comparisons of fields slices.')
|
||||
|
||||
parser.add_argument('-a','--axis', type=int, default=0, help='Axis along which the slices will be taken.')
|
||||
parser.add_argument('-i','--index', type=int, default=None, help='Index of the slice along the axis.')
|
||||
parser.add_argument('-d', '--directory', type=str, required=True, help='Directory containing the fields files.')
|
||||
parser.add_argument('-ref', '--reference', type=str, default=None, help='Reference field file.')
|
||||
parser.add_argument('-f', '--filenames', type=str, nargs='+', required=True, help='Field files to be plotted.')
|
||||
parser.add_argument('-o', '--output', type=str, default=None, help='Output plot file name.')
|
||||
parser.add_argument('-l', '--labels', type=str, nargs='+', default=None, help='Labels for each field.')
|
||||
parser.add_argument('-c', '--cmap', type=str, default='viridis', help='Colormap to be used for plotting.')
|
||||
parser.add_argument('-vmin', type=float, default=None, help='Minimum value for the colorbar.')
|
||||
parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.')
|
||||
parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.')
|
||||
parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
from pysbmy.field import read_field
|
||||
|
||||
ref_field = read_field(args.directory+args.reference)
|
||||
fields = [read_field(args.directory+f) for f in args.filenames]
|
||||
|
||||
if args.index is None:
|
||||
index = ref_field.N0//2
|
||||
else:
|
||||
index=args.index
|
||||
|
||||
match args.axis:
|
||||
case 0 | 'x':
|
||||
reference = ref_field.data[index,:,:]
|
||||
fields = [f.data[index,:,:] for f in fields]
|
||||
case 1 | 'y':
|
||||
reference = ref_field.data[:,index,:]
|
||||
fields = [f.data[:,index,:] for f in fields]
|
||||
case 2 | 'z':
|
||||
reference = ref_field.data[:,:,index]
|
||||
fields = [f.data[:,:,index] for f in fields]
|
||||
case _:
|
||||
raise ValueError(f"Wrong axis provided : {args.axis}")
|
||||
|
||||
if args.log_scale:
|
||||
reference = np.log10(2.+reference)
|
||||
fields = [np.log10(2.+f) for f in fields]
|
||||
|
||||
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap)
|
||||
fig.suptitle(args.title)
|
||||
|
||||
if args.output is not None:
|
||||
fig.savefig(args.output)
|
||||
else:
|
||||
fig.savefig(args.directory+'slices.png')
|
Loading…
Add table
Add a link
Reference in a new issue