improving slice to have the ticks

This commit is contained in:
Mayeul Aubin 2025-03-24 17:33:03 +01:00
parent 536d3df365
commit 06edf57e24

View file

@ -1,13 +1,18 @@
import numpy as np import numpy as np
import sys
sys.path.append('/home/aubin/Simbelmyne/sbmy_control/')
from cosmo_params import register_arguments_cosmo, parse_arguments_cosmo
fs = 18 fs = 18
fs_titles = fs -4 fs_titles = fs -4
def plot_imshow_with_reference( data_list, def plot_imshow_with_reference( data_list,
reference, reference=None,
titles, titles=None,
vmin=None, vmin=None,
vmax=None, vmax=None,
L=None,
cmap='viridis'): cmap='viridis'):
""" """
Plot the imshow of a list of 2D arrays with two rows: one for the data itself, Plot the imshow of a list of 2D arrays with two rows: one for the data itself,
@ -24,37 +29,68 @@ def plot_imshow_with_reference( data_list,
if titles is None: if titles is None:
titles = [None for f in data_list] titles = [None for f in data_list]
if L is None:
L = [len(data) for data in data_list]
elif isinstance(L, int) or isinstance(L, float):
L = [L for data in data_list]
sep = 10 if L[0] < 100 else 20 if L[0] < 200 else 100
ticks = [np.arange(0, l+1, sep)*len(dat)/l for l, dat in zip(L,data_list)]
tick_labels = [np.arange(0, l+1, sep) for l in L]
def score(data, reference): def score(data, reference):
return np.linalg.norm(data-reference)/np.linalg.norm(reference) return np.linalg.norm(data-reference)/np.linalg.norm(reference)
n = len(data_list) n = len(data_list)
fig, axes = plt.subplots(2, n, figsize=(5 * n, 10)) fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2))
if vmin is None or vmax is None: if vmin is None or vmax is None:
vmin = min(np.quantile(data,0.01) for data in data_list) vmin = min(np.quantile(data,0.01) for data in data_list)
vmax = max(np.quantile(data,0.99) for data in data_list) vmax = max(np.quantile(data,0.99) for data in data_list)
if reference is not None:
vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list) vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list)
vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list) vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list)
else: else:
vmin_diff = vmin vmin_diff = vmin
vmax_diff = vmax vmax_diff = vmax
if reference is not None:
# Plot the data itself # Plot the data itself
for i, data in enumerate(data_list): for i, data in enumerate(data_list):
im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes[0, i].set_title(titles[i], fontsize=fs_titles) axes[0, i].set_title(titles[i], fontsize=fs_titles)
axes[0, i].set_xticks(ticks[i])
axes[0, i].set_yticks(ticks[i])
axes[0, i].set_xticklabels(tick_labels[i])
axes[0, i].set_yticklabels(tick_labels[i])
axes[0, i].set_xlabel('Mpc/h')
fig.colorbar(im, ax=axes[0, :], orientation='vertical') fig.colorbar(im, ax=axes[0, :], orientation='vertical')
# Plot the data compared to the reference # Plot the data compared to the reference
for i, data in enumerate(data_list): for i, data in enumerate(data_list):
im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff) im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles) axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles)
axes[1, i].set_xticks(ticks[i])
axes[1, i].set_yticks(ticks[i])
axes[1, i].set_xticklabels(tick_labels[i])
axes[1, i].set_yticklabels(tick_labels[i])
axes[1, i].set_xlabel('Mpc/h')
fig.colorbar(im, ax=axes[1, :], orientation='vertical') fig.colorbar(im, ax=axes[1, :], orientation='vertical')
# Add the score on the plots # Add the score on the plots
for i, data in enumerate(data_list): for i, data in enumerate(data_list):
axes[1, i].text(0.5, 0.9, f"Score: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white') axes[1, i].text(0.5, 0.9, f"Score: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white')
# plt.tight_layout() # plt.tight_layout()
else:
for i, data in enumerate(data_list):
im = axes[i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
axes[i].set_title(titles[i], fontsize=fs_titles)
axes[0, i].set_xticks(ticks[i])
axes[0, i].set_yticks(ticks[i])
axes[0, i].set_xticklabels(tick_labels[i])
axes[0, i].set_yticklabels(tick_labels[i])
fig.colorbar(im, ax=axes[:], orientation='vertical')
return fig, axes return fig, axes
@ -76,11 +112,14 @@ if __name__ == "__main__":
parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.') parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.')
parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.') parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.')
register_arguments_cosmo(parser)
args = parser.parse_args() args = parser.parse_args()
from pysbmy.field import read_field from pysbmy.field import read_field
from pysbmy.cosmology import d_plus
ref_field = read_field(args.directory+args.reference) ref_field = read_field(args.directory+args.reference) if args.reference is not None else None
fields = [read_field(args.directory+f) for f in args.filenames] fields = [read_field(args.directory+f) for f in args.filenames]
if args.index is None: if args.index is None:
@ -88,24 +127,33 @@ if __name__ == "__main__":
else: else:
index=args.index index=args.index
# args.labels=[f"a={f.time:.2f}" for f in fields]
L = [f.L0 for f in fields]
match args.axis: match args.axis:
case 0 | 'x': case 0 | 'x':
reference = ref_field.data[index,:,:] reference = ref_field.data[index,:,:] if ref_field is not None else None
fields = [f.data[index,:,:] for f in fields] fields = [f.data[index,:,:] for f in fields]
# reference = ref_field.data[index,:,:]/d_plus(1e-3,ref_field.time,parse_arguments_cosmo(args))
# fields = [f.data[index,:,:]/d_plus(1e-3,f.time,parse_arguments_cosmo(args)) for f in fields]
# reference = ref_field.data[index,:,:]/d_plus(1e-3,0.05,parse_arguments_cosmo(args))
# fields = [f.data[index,:,:]/d_plus(1e-3,time,parse_arguments_cosmo(args)) for f,time in zip(fields,[0.05, 1.0])]
case 1 | 'y': case 1 | 'y':
reference = ref_field.data[:,index,:] reference = ref_field.data[:,index,:] if ref_field is not None else None
fields = [f.data[:,index,:] for f in fields] fields = [f.data[:,index,:] for f in fields]
case 2 | 'z': case 2 | 'z':
reference = ref_field.data[:,:,index] reference = ref_field.data[:,:,index] if ref_field is not None else None
fields = [f.data[:,:,index] for f in fields] fields = [f.data[:,:,index] for f in fields]
case _: case _:
raise ValueError(f"Wrong axis provided : {args.axis}") raise ValueError(f"Wrong axis provided : {args.axis}")
if args.log_scale: if args.log_scale:
reference = np.log10(2.+reference) reference = np.log10(2.+reference) if ref_field is not None else None
fields = [np.log10(2.+f) for f in fields] fields = [np.log10(2.+f) for f in fields]
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap)
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L)
fig.suptitle(args.title) fig.suptitle(args.title)
if args.output is not None: if args.output is not None: