Compare commits
No commits in common. "9108ec488c354da43267c236834577e01907d29d" and "f63a20bf5b8b3d4b440f8d057cb02f622849af37" have entirely different histories.
9108ec488c
...
f63a20bf5b
3 changed files with 74 additions and 183 deletions
|
@ -1,36 +0,0 @@
|
||||||
|
|
||||||
|
|
||||||
def register_colormaps(colormaps):
|
|
||||||
|
|
||||||
# Register cmasher
|
|
||||||
try:
|
|
||||||
import cmasher as cma
|
|
||||||
for name, cmap in cma.cm.cmap_d.items():
|
|
||||||
try:
|
|
||||||
colormaps.register(name=name, cmap=cmap)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Register cmocean
|
|
||||||
try:
|
|
||||||
import cmocean as cmo
|
|
||||||
for name, cmap in cmo.cm.cmap_d.items():
|
|
||||||
try:
|
|
||||||
colormaps.register(name=name, cmap=cmap)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Register cmcrameri
|
|
||||||
try:
|
|
||||||
import cmcrameri as cmc
|
|
||||||
for name, cmap in cmc.cm.cmaps.items():
|
|
||||||
try:
|
|
||||||
colormaps.register(name=name, cmap=cmap)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
|
@ -5,24 +5,6 @@ kmax = 2e0
|
||||||
Nk = 50
|
Nk = 50
|
||||||
AliasingCorr=False
|
AliasingCorr=False
|
||||||
|
|
||||||
def crop_field(field, Ncrop):
|
|
||||||
|
|
||||||
if Ncrop is None or Ncrop == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
elif Ncrop > 0:
|
|
||||||
field.data = field.data[Ncrop:-Ncrop, Ncrop:-Ncrop, Ncrop:-Ncrop]
|
|
||||||
d0 = field.L0/field.N0
|
|
||||||
d1 = field.L1/field.N1
|
|
||||||
d2 = field.L2/field.N2
|
|
||||||
field.N0 -= 2*Ncrop
|
|
||||||
field.N1 -= 2*Ncrop
|
|
||||||
field.N2 -= 2*Ncrop
|
|
||||||
field.L0 = field.N0*d0
|
|
||||||
field.L1 = field.N1*d1
|
|
||||||
field.L2 = field.N2*d2
|
|
||||||
|
|
||||||
|
|
||||||
def get_power_spectrum(field, kmin=kmin, kmax=kmax, Nk=Nk, G=None):
|
def get_power_spectrum(field, kmin=kmin, kmax=kmax, Nk=Nk, G=None):
|
||||||
from pysbmy.power import PowerSpectrum
|
from pysbmy.power import PowerSpectrum
|
||||||
from pysbmy.fft import FourierGrid
|
from pysbmy.fft import FourierGrid
|
||||||
|
@ -109,8 +91,7 @@ def plot_power_spectra(filenames,
|
||||||
figsize=(8,4),
|
figsize=(8,4),
|
||||||
dpi=300,
|
dpi=300,
|
||||||
ax=None,
|
ax=None,
|
||||||
fig=None,
|
fig=None,):
|
||||||
Ncrop=None,):
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from pysbmy.field import read_field
|
from pysbmy.field import read_field
|
||||||
|
@ -129,7 +110,6 @@ def plot_power_spectra(filenames,
|
||||||
|
|
||||||
for i, filename in enumerate(filenames):
|
for i, filename in enumerate(filenames):
|
||||||
field = read_field(filename)
|
field = read_field(filename)
|
||||||
crop_field(field, Ncrop)
|
|
||||||
_, G, k, _ = add_power_spectrum_to_plot(ax=ax,
|
_, G, k, _ = add_power_spectrum_to_plot(ax=ax,
|
||||||
field=field,
|
field=field,
|
||||||
Pk_ref=Pk_ref,
|
Pk_ref=Pk_ref,
|
||||||
|
@ -148,7 +128,7 @@ def plot_power_spectra(filenames,
|
||||||
ax.set_ylim(ylims)
|
ax.set_ylim(ylims)
|
||||||
if yticks is not None:
|
if yticks is not None:
|
||||||
ax.set_yticks(yticks)
|
ax.set_yticks(yticks)
|
||||||
ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-0)
|
ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-10)
|
||||||
|
|
||||||
if Pk_ref is not None:
|
if Pk_ref is not None:
|
||||||
ax.set_ylabel(r'$P(k)/P_\mathrm{ref}(k)$')
|
ax.set_ylabel(r'$P(k)/P_\mathrm{ref}(k)$')
|
||||||
|
@ -183,9 +163,7 @@ def plot_cross_correlations(filenames_A,
|
||||||
figsize=(8,4),
|
figsize=(8,4),
|
||||||
dpi=300,
|
dpi=300,
|
||||||
ax=None,
|
ax=None,
|
||||||
fig=None,
|
fig=None,):
|
||||||
Ncrop=None,
|
|
||||||
):
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from pysbmy.field import read_field
|
from pysbmy.field import read_field
|
||||||
|
@ -203,11 +181,9 @@ def plot_cross_correlations(filenames_A,
|
||||||
markers = [None for f in filenames_A]
|
markers = [None for f in filenames_A]
|
||||||
|
|
||||||
field_B = read_field(filename_B)
|
field_B = read_field(filename_B)
|
||||||
crop_field(field_B, Ncrop)
|
|
||||||
|
|
||||||
for i, filename_A in enumerate(filenames_A):
|
for i, filename_A in enumerate(filenames_A):
|
||||||
field_A = read_field(filename_A)
|
field_A = read_field(filename_A)
|
||||||
crop_field(field_A, Ncrop)
|
|
||||||
_, G, k, _ = add_cross_correlations_to_plot(ax=ax,
|
_, G, k, _ = add_cross_correlations_to_plot(ax=ax,
|
||||||
field_A=field_A,
|
field_A=field_A,
|
||||||
field_B=field_B,
|
field_B=field_B,
|
||||||
|
@ -226,7 +202,7 @@ def plot_cross_correlations(filenames_A,
|
||||||
ax.set_ylim(ylims)
|
ax.set_ylim(ylims)
|
||||||
if yticks is not None:
|
if yticks is not None:
|
||||||
ax.set_yticks(yticks)
|
ax.set_yticks(yticks)
|
||||||
ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-0)
|
ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-10)
|
||||||
ax.set_ylabel('$R(k)$')
|
ax.set_ylabel('$R(k)$')
|
||||||
|
|
||||||
if bound1 is not None:
|
if bound1 is not None:
|
||||||
|
@ -279,7 +255,6 @@ if __name__ == "__main__":
|
||||||
parser.add_argument('-t','--title', type=str, default=None, help='Title of the plot.')
|
parser.add_argument('-t','--title', type=str, default=None, help='Title of the plot.')
|
||||||
parser.add_argument('-yrp', '--ylim_power', type=float, nargs=2, default=[0.9,1.1], help='Y-axis limits.')
|
parser.add_argument('-yrp', '--ylim_power', type=float, nargs=2, default=[0.9,1.1], help='Y-axis limits.')
|
||||||
parser.add_argument('-yrc', '--ylim_corr', type=float, nargs=2, default=[0.99,1.001], help='Y-axis limits.')
|
parser.add_argument('-yrc', '--ylim_corr', type=float, nargs=2, default=[0.99,1.001], help='Y-axis limits.')
|
||||||
parser.add_argument('--crop', type=int, default=None, help='Remove the outter N pixels of the fields.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -289,9 +264,7 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
if args.reference is not None:
|
if args.reference is not None:
|
||||||
from pysbmy.field import read_field
|
from pysbmy.field import read_field
|
||||||
F_ref = read_field(args.directory+args.reference)
|
G, _, Pk_ref = get_power_spectrum(read_field(args.directory+args.reference), kmin=kmin, kmax=kmax, Nk=Nk)
|
||||||
crop_field(F_ref, args.crop)
|
|
||||||
G, _, Pk_ref = get_power_spectrum(F_ref, kmin=kmin, kmax=kmax, Nk=Nk)
|
|
||||||
else:
|
else:
|
||||||
Pk_ref = None
|
Pk_ref = None
|
||||||
G = None
|
G = None
|
||||||
|
@ -306,7 +279,6 @@ if __name__ == "__main__":
|
||||||
if args.power_spectrum and args.cross_correlation:
|
if args.power_spectrum and args.cross_correlation:
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
fig, axes = plt.subplots(2, 1, figsize=(8,8))
|
fig, axes = plt.subplots(2, 1, figsize=(8,8))
|
||||||
fig.subplots_adjust(hspace=0.3)
|
|
||||||
plot_power_spectra(filenames=filenames,
|
plot_power_spectra(filenames=filenames,
|
||||||
labels=args.labels,
|
labels=args.labels,
|
||||||
colors=args.colors,
|
colors=args.colors,
|
||||||
|
@ -322,9 +294,7 @@ if __name__ == "__main__":
|
||||||
kmax=kmax,
|
kmax=kmax,
|
||||||
Nk=Nk,
|
Nk=Nk,
|
||||||
ax=axes[0],
|
ax=axes[0],
|
||||||
fig=fig,
|
fig=fig)
|
||||||
Ncrop=args.crop,
|
|
||||||
)
|
|
||||||
|
|
||||||
plot_cross_correlations(filenames_A=filenames,
|
plot_cross_correlations(filenames_A=filenames,
|
||||||
filename_B=args.directory+args.reference,
|
filename_B=args.directory+args.reference,
|
||||||
|
@ -341,9 +311,7 @@ if __name__ == "__main__":
|
||||||
kmax=kmax,
|
kmax=kmax,
|
||||||
Nk=Nk,
|
Nk=Nk,
|
||||||
ax=axes[1],
|
ax=axes[1],
|
||||||
fig=fig,
|
fig=fig)
|
||||||
Ncrop=args.crop,
|
|
||||||
)
|
|
||||||
|
|
||||||
axes[1].legend(loc='lower left')
|
axes[1].legend(loc='lower left')
|
||||||
axes[0].set_title("Power Spectrum")
|
axes[0].set_title("Power Spectrum")
|
||||||
|
@ -366,9 +334,7 @@ if __name__ == "__main__":
|
||||||
bound2=0.02,
|
bound2=0.02,
|
||||||
kmin=kmin,
|
kmin=kmin,
|
||||||
kmax=kmax,
|
kmax=kmax,
|
||||||
Nk=Nk,
|
Nk=Nk)
|
||||||
Ncrop=args.crop,
|
|
||||||
)
|
|
||||||
ax.legend()
|
ax.legend()
|
||||||
if args.title is not None:
|
if args.title is not None:
|
||||||
ax.set_title(args.title)
|
ax.set_title(args.title)
|
||||||
|
@ -387,9 +353,7 @@ if __name__ == "__main__":
|
||||||
bound2=0.002,
|
bound2=0.002,
|
||||||
kmin=kmin,
|
kmin=kmin,
|
||||||
kmax=kmax,
|
kmax=kmax,
|
||||||
Nk=Nk,
|
Nk=Nk)
|
||||||
Ncrop=args.crop,
|
|
||||||
)
|
|
||||||
ax.legend(loc='lower left')
|
ax.legend(loc='lower left')
|
||||||
if args.title is not None:
|
if args.title is not None:
|
||||||
ax.set_title(args.title)
|
ax.set_title(args.title)
|
||||||
|
|
|
@ -2,21 +2,10 @@ import numpy as np
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.append('/home/aubin/Simbelmyne/sbmy_control/')
|
sys.path.append('/home/aubin/Simbelmyne/sbmy_control/')
|
||||||
|
from cosmo_params import register_arguments_cosmo, parse_arguments_cosmo
|
||||||
|
|
||||||
fs = 18
|
fs = 18
|
||||||
fs_titles = fs - 4
|
fs_titles = fs -4
|
||||||
|
|
||||||
def add_ax_ticks(ax, ticks, tick_labels):
|
|
||||||
from matplotlib import ticker
|
|
||||||
ax.set_xticks(ticks)
|
|
||||||
ax.set_yticks(ticks)
|
|
||||||
ax.set_xticklabels(tick_labels)
|
|
||||||
ax.set_yticklabels(tick_labels)
|
|
||||||
ax.set_xlabel('Mpc/h')
|
|
||||||
ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
|
||||||
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_imshow_with_reference( data_list,
|
def plot_imshow_with_reference( data_list,
|
||||||
reference=None,
|
reference=None,
|
||||||
|
@ -24,9 +13,7 @@ def plot_imshow_with_reference( data_list,
|
||||||
vmin=None,
|
vmin=None,
|
||||||
vmax=None,
|
vmax=None,
|
||||||
L=None,
|
L=None,
|
||||||
cmap='viridis',
|
cmap='viridis'):
|
||||||
cmap_diff='PuOr',
|
|
||||||
ref_label="Reference"):
|
|
||||||
"""
|
"""
|
||||||
Plot the imshow of a list of 2D arrays with two rows: one for the data itself,
|
Plot the imshow of a list of 2D arrays with two rows: one for the data itself,
|
||||||
one for the data compared to a reference. Each row will have a common colorbar.
|
one for the data compared to a reference. Each row will have a common colorbar.
|
||||||
|
@ -38,9 +25,7 @@ def plot_imshow_with_reference( data_list,
|
||||||
- cmap: colormap to be used for plotting
|
- cmap: colormap to be used for plotting
|
||||||
"""
|
"""
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import ticker
|
||||||
from colormaps import register_colormaps
|
|
||||||
register_colormaps(plt.colormaps)
|
|
||||||
|
|
||||||
if titles is None:
|
if titles is None:
|
||||||
titles = [None for f in data_list]
|
titles = [None for f in data_list]
|
||||||
|
@ -58,7 +43,7 @@ def plot_imshow_with_reference( data_list,
|
||||||
return np.linalg.norm(data-reference)/np.linalg.norm(reference)
|
return np.linalg.norm(data-reference)/np.linalg.norm(reference)
|
||||||
|
|
||||||
n = len(data_list)
|
n = len(data_list)
|
||||||
fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2), squeeze = False)
|
fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2))
|
||||||
|
|
||||||
if vmin is None or vmax is None:
|
if vmin is None or vmax is None:
|
||||||
vmin = min(np.quantile(data,0.01) for data in data_list)
|
vmin = min(np.quantile(data,0.01) for data in data_list)
|
||||||
|
@ -67,88 +52,72 @@ def plot_imshow_with_reference( data_list,
|
||||||
if reference is not None:
|
if reference is not None:
|
||||||
vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list)
|
vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list)
|
||||||
vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list)
|
vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list)
|
||||||
vmin_diff = min(vmin_diff, -vmax_diff)
|
|
||||||
vmax_diff = -vmin_diff
|
|
||||||
else:
|
else:
|
||||||
vmin_diff = vmin
|
vmin_diff = vmin
|
||||||
vmax_diff = vmax
|
vmax_diff = vmax
|
||||||
|
|
||||||
# Plot the data itself
|
|
||||||
for i, data in enumerate(data_list):
|
|
||||||
im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
|
|
||||||
axes[0, i].set_title(titles[i], fontsize=fs_titles)
|
|
||||||
add_ax_ticks(axes[0, i], ticks[i], tick_labels[i])
|
|
||||||
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
|
|
||||||
|
|
||||||
if reference is not None:
|
if reference is not None:
|
||||||
|
# Plot the data itself
|
||||||
|
for i, data in enumerate(data_list):
|
||||||
|
im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
|
||||||
|
axes[0, i].set_title(titles[i], fontsize=fs_titles)
|
||||||
|
axes[0, i].set_xticks(ticks[i])
|
||||||
|
axes[0, i].set_yticks(ticks[i])
|
||||||
|
axes[0, i].set_xticklabels(tick_labels[i])
|
||||||
|
axes[0, i].set_yticklabels(tick_labels[i])
|
||||||
|
axes[0, i].set_xlabel('Mpc/h')
|
||||||
|
axes[0, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
axes[0, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
|
||||||
|
|
||||||
# Plot the data compared to the reference
|
# Plot the data compared to the reference
|
||||||
for i, data in enumerate(data_list):
|
for i, data in enumerate(data_list):
|
||||||
im = axes[1, i].imshow(data - reference, cmap=cmap_diff, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
|
im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff)
|
||||||
axes[1, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles)
|
axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles)
|
||||||
add_ax_ticks(axes[1, i], ticks[i], tick_labels[i])
|
axes[1, i].set_xticks(ticks[i])
|
||||||
|
axes[1, i].set_yticks(ticks[i])
|
||||||
|
axes[1, i].set_xticklabels(tick_labels[i])
|
||||||
|
axes[1, i].set_yticklabels(tick_labels[i])
|
||||||
|
axes[1, i].set_xlabel('Mpc/h')
|
||||||
|
axes[1, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
axes[1, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
fig.colorbar(im, ax=axes[1, :], orientation='vertical')
|
fig.colorbar(im, ax=axes[1, :], orientation='vertical')
|
||||||
|
|
||||||
# Add the score on the plots
|
# Add the score on the plots
|
||||||
for i, data in enumerate(data_list):
|
for i, data in enumerate(data_list):
|
||||||
axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='black')
|
axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white')
|
||||||
# plt.tight_layout()
|
# plt.tight_layout()
|
||||||
|
else:
|
||||||
|
|
||||||
|
if len(data_list) == 1:
|
||||||
|
data_list = data_list[0]
|
||||||
|
im = axes.imshow(data_list, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
|
||||||
|
axes.set_title(titles[0], fontsize=fs_titles)
|
||||||
|
axes.set_xticks(ticks[0])
|
||||||
|
axes.set_yticks(ticks[0])
|
||||||
|
axes.set_xticklabels(tick_labels[0])
|
||||||
|
axes.set_yticklabels(tick_labels[0])
|
||||||
|
axes.set_xlabel('Mpc/h')
|
||||||
|
axes.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
axes.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
fig.colorbar(im, ax=axes, orientation='vertical')
|
||||||
|
|
||||||
|
else:
|
||||||
|
for i, data in enumerate(data_list):
|
||||||
|
im = axes[i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
|
||||||
|
axes[i].set_title(titles[i], fontsize=fs_titles)
|
||||||
|
axes[i].set_xticks(ticks[i])
|
||||||
|
axes[i].set_yticks(ticks[i])
|
||||||
|
axes[i].set_xticklabels(tick_labels[i])
|
||||||
|
axes[i].set_yticklabels(tick_labels[i])
|
||||||
|
axes[i].set_xlabel('Mpc/h')
|
||||||
|
axes[i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
axes[i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
|
||||||
|
fig.colorbar(im, ax=axes[:], orientation='vertical')
|
||||||
|
|
||||||
return fig, axes
|
return fig, axes
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def plot_imshow_diff(data_list,
|
|
||||||
reference,
|
|
||||||
titles,
|
|
||||||
vmin=None,
|
|
||||||
vmax=None,
|
|
||||||
L=None,
|
|
||||||
cmap='viridis',
|
|
||||||
ref_label="Reference"):
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from colormaps import register_colormaps
|
|
||||||
register_colormaps(plt.colormaps)
|
|
||||||
|
|
||||||
if reference is None:
|
|
||||||
raise ValueError("Reference field is None")
|
|
||||||
|
|
||||||
if titles is None:
|
|
||||||
titles = [None for f in data_list]
|
|
||||||
|
|
||||||
if L is None:
|
|
||||||
L = [len(data) for data in data_list]
|
|
||||||
elif isinstance(L, int) or isinstance(L, float):
|
|
||||||
L = [L for data in data_list]
|
|
||||||
|
|
||||||
sep = 10 if L[0] < 50 else 20 if L[0] < 100 else 50 if L[0]<250 else 100 if L[0] < 500 else 200 if L[0] < 1000 else 500 if L[0] < 2500 else 1000
|
|
||||||
ticks = [np.arange(0, l+1, sep)*len(dat)/l for l, dat in zip(L,data_list)]
|
|
||||||
tick_labels = [np.arange(0, l+1, sep) for l in L]
|
|
||||||
|
|
||||||
def score(data, reference):
|
|
||||||
return np.linalg.norm(data-reference)/np.linalg.norm(reference)
|
|
||||||
|
|
||||||
n = len(data_list)
|
|
||||||
fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), dpi=max(500, data_list[0].shape[0]//2), squeeze = False)
|
|
||||||
|
|
||||||
if vmin is None or vmax is None:
|
|
||||||
vmin = min(np.quantile(data-reference,0.01) for data in data_list)
|
|
||||||
vmax = max(np.quantile(data-reference,0.99) for data in data_list)
|
|
||||||
vmin = min(vmin, -vmax)
|
|
||||||
vmax = -vmin
|
|
||||||
|
|
||||||
# Plot the data compared to the reference
|
|
||||||
for i, data in enumerate(data_list):
|
|
||||||
im = axes[0, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax)
|
|
||||||
axes[0, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles)
|
|
||||||
add_ax_ticks(axes[0, i], ticks[i], tick_labels[i])
|
|
||||||
fig.colorbar(im, ax=axes[0, :], orientation='vertical')
|
|
||||||
|
|
||||||
return fig, axes
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
parser = ArgumentParser(description='Comparisons of fields slices.')
|
parser = ArgumentParser(description='Comparisons of fields slices.')
|
||||||
|
@ -165,24 +134,16 @@ if __name__ == "__main__":
|
||||||
parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.')
|
parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.')
|
||||||
parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.')
|
parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.')
|
||||||
parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.')
|
parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.')
|
||||||
parser.add_argument('--diff', action='store_true', help='Plot only the difference with the reference field.')
|
|
||||||
parser.add_argument('--ref_label', type=str, default='Reference', help='Label for the reference field.')
|
# register_arguments_cosmo(parser)
|
||||||
parser.add_argument('--cmap_diff', type=str, default='PuOr', help='Colormap to be used for the difference plot.')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
from pysbmy.field import read_field
|
from pysbmy.field import read_field
|
||||||
|
# from pysbmy.cosmology import d_plus
|
||||||
|
|
||||||
ref_label = args.ref_label
|
|
||||||
ref_field = read_field(args.directory+args.reference) if args.reference is not None else None
|
ref_field = read_field(args.directory+args.reference) if args.reference is not None else None
|
||||||
fields = []
|
fields = [read_field(args.directory+f) for f in args.filenames]
|
||||||
for k,f in enumerate(args.filenames):
|
|
||||||
if args.reference is not None and f == args.reference:
|
|
||||||
fields.append(ref_field) # Simply copy the reference field instead of reading it again
|
|
||||||
if args.labels is not None:
|
|
||||||
ref_label = args.labels[k] # Use the label of the field as the reference label
|
|
||||||
else:
|
|
||||||
fields.append(read_field(args.directory+f))
|
|
||||||
|
|
||||||
if args.index is None:
|
if args.index is None:
|
||||||
index = fields[0].N0//2
|
index = fields[0].N0//2
|
||||||
|
@ -196,6 +157,10 @@ if __name__ == "__main__":
|
||||||
case 0 | 'x':
|
case 0 | 'x':
|
||||||
reference = ref_field.data[index,:,:] if ref_field is not None else None
|
reference = ref_field.data[index,:,:] if ref_field is not None else None
|
||||||
fields = [f.data[index,:,:] for f in fields]
|
fields = [f.data[index,:,:] for f in fields]
|
||||||
|
# reference = ref_field.data[index,:,:]/d_plus(1e-3,ref_field.time,parse_arguments_cosmo(args))
|
||||||
|
# fields = [f.data[index,:,:]/d_plus(1e-3,f.time,parse_arguments_cosmo(args)) for f in fields]
|
||||||
|
# reference = ref_field.data[index,:,:]/d_plus(1e-3,0.05,parse_arguments_cosmo(args))
|
||||||
|
# fields = [f.data[index,:,:]/d_plus(1e-3,time,parse_arguments_cosmo(args)) for f,time in zip(fields,[0.05, 1.0])]
|
||||||
case 1 | 'y':
|
case 1 | 'y':
|
||||||
reference = ref_field.data[:,index,:] if ref_field is not None else None
|
reference = ref_field.data[:,index,:] if ref_field is not None else None
|
||||||
fields = [f.data[:,index,:] for f in fields]
|
fields = [f.data[:,index,:] for f in fields]
|
||||||
|
@ -210,10 +175,8 @@ if __name__ == "__main__":
|
||||||
fields = [np.log10(2.+f) for f in fields]
|
fields = [np.log10(2.+f) for f in fields]
|
||||||
|
|
||||||
|
|
||||||
if args.diff:
|
|
||||||
fig, axes = plot_imshow_diff(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap_diff, L=L, ref_label=ref_label)
|
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L)
|
||||||
else:
|
|
||||||
fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L, ref_label=ref_label, cmap_diff=args.cmap_diff)
|
|
||||||
fig.suptitle(args.title)
|
fig.suptitle(args.title)
|
||||||
|
|
||||||
if args.output is not None:
|
if args.output is not None:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue