diff --git a/analysis/colormaps.py b/analysis/colormaps.py deleted file mode 100644 index 54ddf3a..0000000 --- a/analysis/colormaps.py +++ /dev/null @@ -1,36 +0,0 @@ - - -def register_colormaps(colormaps): - - # Register cmasher - try: - import cmasher as cma - for name, cmap in cma.cm.cmap_d.items(): - try: - colormaps.register(name=name, cmap=cmap) - except ValueError: - pass - except ImportError: - pass - - # Register cmocean - try: - import cmocean as cmo - for name, cmap in cmo.cm.cmap_d.items(): - try: - colormaps.register(name=name, cmap=cmap) - except ValueError: - pass - except ImportError: - pass - - # Register cmcrameri - try: - import cmcrameri as cmc - for name, cmap in cmc.cm.cmaps.items(): - try: - colormaps.register(name=name, cmap=cmap) - except ValueError: - pass - except ImportError: - pass diff --git a/analysis/power_spectrum.py b/analysis/power_spectrum.py index 9d44005..93bb8ee 100644 --- a/analysis/power_spectrum.py +++ b/analysis/power_spectrum.py @@ -5,24 +5,6 @@ kmax = 2e0 Nk = 50 AliasingCorr=False -def crop_field(field, Ncrop): - - if Ncrop is None or Ncrop == 0: - return - - elif Ncrop > 0: - field.data = field.data[Ncrop:-Ncrop, Ncrop:-Ncrop, Ncrop:-Ncrop] - d0 = field.L0/field.N0 - d1 = field.L1/field.N1 - d2 = field.L2/field.N2 - field.N0 -= 2*Ncrop - field.N1 -= 2*Ncrop - field.N2 -= 2*Ncrop - field.L0 = field.N0*d0 - field.L1 = field.N1*d1 - field.L2 = field.N2*d2 - - def get_power_spectrum(field, kmin=kmin, kmax=kmax, Nk=Nk, G=None): from pysbmy.power import PowerSpectrum from pysbmy.fft import FourierGrid @@ -109,8 +91,7 @@ def plot_power_spectra(filenames, figsize=(8,4), dpi=300, ax=None, - fig=None, - Ncrop=None,): + fig=None,): import matplotlib.pyplot as plt from pysbmy.field import read_field @@ -129,7 +110,6 @@ def plot_power_spectra(filenames, for i, filename in enumerate(filenames): field = read_field(filename) - crop_field(field, Ncrop) _, G, k, _ = add_power_spectrum_to_plot(ax=ax, field=field, Pk_ref=Pk_ref, @@ -148,7 +128,7 @@ def plot_power_spectra(filenames, ax.set_ylim(ylims) if yticks is not None: ax.set_yticks(yticks) - ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-0) + ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-10) if Pk_ref is not None: ax.set_ylabel(r'$P(k)/P_\mathrm{ref}(k)$') @@ -183,9 +163,7 @@ def plot_cross_correlations(filenames_A, figsize=(8,4), dpi=300, ax=None, - fig=None, - Ncrop=None, - ): + fig=None,): import matplotlib.pyplot as plt from pysbmy.field import read_field @@ -203,11 +181,9 @@ def plot_cross_correlations(filenames_A, markers = [None for f in filenames_A] field_B = read_field(filename_B) - crop_field(field_B, Ncrop) for i, filename_A in enumerate(filenames_A): field_A = read_field(filename_A) - crop_field(field_A, Ncrop) _, G, k, _ = add_cross_correlations_to_plot(ax=ax, field_A=field_A, field_B=field_B, @@ -226,7 +202,7 @@ def plot_cross_correlations(filenames_A, ax.set_ylim(ylims) if yticks is not None: ax.set_yticks(yticks) - ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-0) + ax.set_xlabel(r'$k$ [$h/\mathrm{Mpc}$]', labelpad=-10) ax.set_ylabel('$R(k)$') if bound1 is not None: @@ -279,7 +255,6 @@ if __name__ == "__main__": parser.add_argument('-t','--title', type=str, default=None, help='Title of the plot.') parser.add_argument('-yrp', '--ylim_power', type=float, nargs=2, default=[0.9,1.1], help='Y-axis limits.') parser.add_argument('-yrc', '--ylim_corr', type=float, nargs=2, default=[0.99,1.001], help='Y-axis limits.') - parser.add_argument('--crop', type=int, default=None, help='Remove the outter N pixels of the fields.') args = parser.parse_args() @@ -289,9 +264,7 @@ if __name__ == "__main__": if args.reference is not None: from pysbmy.field import read_field - F_ref = read_field(args.directory+args.reference) - crop_field(F_ref, args.crop) - G, _, Pk_ref = get_power_spectrum(F_ref, kmin=kmin, kmax=kmax, Nk=Nk) + G, _, Pk_ref = get_power_spectrum(read_field(args.directory+args.reference), kmin=kmin, kmax=kmax, Nk=Nk) else: Pk_ref = None G = None @@ -306,7 +279,6 @@ if __name__ == "__main__": if args.power_spectrum and args.cross_correlation: import matplotlib.pyplot as plt fig, axes = plt.subplots(2, 1, figsize=(8,8)) - fig.subplots_adjust(hspace=0.3) plot_power_spectra(filenames=filenames, labels=args.labels, colors=args.colors, @@ -322,9 +294,7 @@ if __name__ == "__main__": kmax=kmax, Nk=Nk, ax=axes[0], - fig=fig, - Ncrop=args.crop, - ) + fig=fig) plot_cross_correlations(filenames_A=filenames, filename_B=args.directory+args.reference, @@ -341,9 +311,7 @@ if __name__ == "__main__": kmax=kmax, Nk=Nk, ax=axes[1], - fig=fig, - Ncrop=args.crop, - ) + fig=fig) axes[1].legend(loc='lower left') axes[0].set_title("Power Spectrum") @@ -366,9 +334,7 @@ if __name__ == "__main__": bound2=0.02, kmin=kmin, kmax=kmax, - Nk=Nk, - Ncrop=args.crop, - ) + Nk=Nk) ax.legend() if args.title is not None: ax.set_title(args.title) @@ -387,9 +353,7 @@ if __name__ == "__main__": bound2=0.002, kmin=kmin, kmax=kmax, - Nk=Nk, - Ncrop=args.crop, - ) + Nk=Nk) ax.legend(loc='lower left') if args.title is not None: ax.set_title(args.title) diff --git a/analysis/slices.py b/analysis/slices.py index 2febe16..4f01418 100644 --- a/analysis/slices.py +++ b/analysis/slices.py @@ -2,21 +2,10 @@ import numpy as np import sys sys.path.append('/home/aubin/Simbelmyne/sbmy_control/') +from cosmo_params import register_arguments_cosmo, parse_arguments_cosmo fs = 18 -fs_titles = fs - 4 - -def add_ax_ticks(ax, ticks, tick_labels): - from matplotlib import ticker - ax.set_xticks(ticks) - ax.set_yticks(ticks) - ax.set_xticklabels(tick_labels) - ax.set_yticklabels(tick_labels) - ax.set_xlabel('Mpc/h') - ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) - - +fs_titles = fs -4 def plot_imshow_with_reference( data_list, reference=None, @@ -24,9 +13,7 @@ def plot_imshow_with_reference( data_list, vmin=None, vmax=None, L=None, - cmap='viridis', - cmap_diff='PuOr', - ref_label="Reference"): + cmap='viridis'): """ Plot the imshow of a list of 2D arrays with two rows: one for the data itself, one for the data compared to a reference. Each row will have a common colorbar. @@ -38,9 +25,7 @@ def plot_imshow_with_reference( data_list, - cmap: colormap to be used for plotting """ import matplotlib.pyplot as plt - - from colormaps import register_colormaps - register_colormaps(plt.colormaps) + from matplotlib import ticker if titles is None: titles = [None for f in data_list] @@ -58,7 +43,7 @@ def plot_imshow_with_reference( data_list, return np.linalg.norm(data-reference)/np.linalg.norm(reference) n = len(data_list) - fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2), squeeze = False) + fig, axes = plt.subplots(1 if reference is None else 2, n, figsize=(5 * n, 5 if reference is None else 5*2), dpi=max(500, data_list[0].shape[0]//2)) if vmin is None or vmax is None: vmin = min(np.quantile(data,0.01) for data in data_list) @@ -67,88 +52,72 @@ def plot_imshow_with_reference( data_list, if reference is not None: vmin_diff = min(np.quantile((data-reference),0.01) for data in data_list) vmax_diff = max(np.quantile((data-reference),0.99) for data in data_list) - vmin_diff = min(vmin_diff, -vmax_diff) - vmax_diff = -vmin_diff else: vmin_diff = vmin vmax_diff = vmax - # Plot the data itself - for i, data in enumerate(data_list): - im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) - axes[0, i].set_title(titles[i], fontsize=fs_titles) - add_ax_ticks(axes[0, i], ticks[i], tick_labels[i]) - fig.colorbar(im, ax=axes[0, :], orientation='vertical') - if reference is not None: + # Plot the data itself + for i, data in enumerate(data_list): + im = axes[0, i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes[0, i].set_title(titles[i], fontsize=fs_titles) + axes[0, i].set_xticks(ticks[i]) + axes[0, i].set_yticks(ticks[i]) + axes[0, i].set_xticklabels(tick_labels[i]) + axes[0, i].set_yticklabels(tick_labels[i]) + axes[0, i].set_xlabel('Mpc/h') + axes[0, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + axes[0, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + fig.colorbar(im, ax=axes[0, :], orientation='vertical') + # Plot the data compared to the reference for i, data in enumerate(data_list): - im = axes[1, i].imshow(data - reference, cmap=cmap_diff, origin='lower', vmin=vmin_diff, vmax=vmax_diff) - axes[1, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles) - add_ax_ticks(axes[1, i], ticks[i], tick_labels[i]) + im = axes[1, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin_diff, vmax=vmax_diff) + axes[1, i].set_title(f'{titles[i]} - Reference', fontsize=fs_titles) + axes[1, i].set_xticks(ticks[i]) + axes[1, i].set_yticks(ticks[i]) + axes[1, i].set_xticklabels(tick_labels[i]) + axes[1, i].set_yticklabels(tick_labels[i]) + axes[1, i].set_xlabel('Mpc/h') + axes[1, i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + axes[1, i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) fig.colorbar(im, ax=axes[1, :], orientation='vertical') # Add the score on the plots for i, data in enumerate(data_list): - axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='black') + axes[1, i].text(0.5, 0.9, f"RMS: {score(data, reference):.2e}", fontsize=10, transform=axes[1, i].transAxes, color='white') # plt.tight_layout() + else: + + if len(data_list) == 1: + data_list = data_list[0] + im = axes.imshow(data_list, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes.set_title(titles[0], fontsize=fs_titles) + axes.set_xticks(ticks[0]) + axes.set_yticks(ticks[0]) + axes.set_xticklabels(tick_labels[0]) + axes.set_yticklabels(tick_labels[0]) + axes.set_xlabel('Mpc/h') + axes.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + axes.yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + fig.colorbar(im, ax=axes, orientation='vertical') + + else: + for i, data in enumerate(data_list): + im = axes[i].imshow(data, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) + axes[i].set_title(titles[i], fontsize=fs_titles) + axes[i].set_xticks(ticks[i]) + axes[i].set_yticks(ticks[i]) + axes[i].set_xticklabels(tick_labels[i]) + axes[i].set_yticklabels(tick_labels[i]) + axes[i].set_xlabel('Mpc/h') + axes[i].xaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + axes[i].yaxis.set_major_formatter(ticker.FormatStrFormatter('%d')) + fig.colorbar(im, ax=axes[:], orientation='vertical') return fig, axes - -def plot_imshow_diff(data_list, - reference, - titles, - vmin=None, - vmax=None, - L=None, - cmap='viridis', - ref_label="Reference"): - - import matplotlib.pyplot as plt - - from colormaps import register_colormaps - register_colormaps(plt.colormaps) - - if reference is None: - raise ValueError("Reference field is None") - - if titles is None: - titles = [None for f in data_list] - - if L is None: - L = [len(data) for data in data_list] - elif isinstance(L, int) or isinstance(L, float): - L = [L for data in data_list] - - sep = 10 if L[0] < 50 else 20 if L[0] < 100 else 50 if L[0]<250 else 100 if L[0] < 500 else 200 if L[0] < 1000 else 500 if L[0] < 2500 else 1000 - ticks = [np.arange(0, l+1, sep)*len(dat)/l for l, dat in zip(L,data_list)] - tick_labels = [np.arange(0, l+1, sep) for l in L] - - def score(data, reference): - return np.linalg.norm(data-reference)/np.linalg.norm(reference) - - n = len(data_list) - fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), dpi=max(500, data_list[0].shape[0]//2), squeeze = False) - - if vmin is None or vmax is None: - vmin = min(np.quantile(data-reference,0.01) for data in data_list) - vmax = max(np.quantile(data-reference,0.99) for data in data_list) - vmin = min(vmin, -vmax) - vmax = -vmin - - # Plot the data compared to the reference - for i, data in enumerate(data_list): - im = axes[0, i].imshow(data - reference, cmap=cmap, origin='lower', vmin=vmin, vmax=vmax) - axes[0, i].set_title(f'{titles[i]} - {ref_label}', fontsize=fs_titles) - add_ax_ticks(axes[0, i], ticks[i], tick_labels[i]) - fig.colorbar(im, ax=axes[0, :], orientation='vertical') - - return fig, axes - - - if __name__ == "__main__": from argparse import ArgumentParser parser = ArgumentParser(description='Comparisons of fields slices.') @@ -165,24 +134,16 @@ if __name__ == "__main__": parser.add_argument('-vmax', type=float, default=None, help='Maximum value for the colorbar.') parser.add_argument('-t', '--title', type=str, default=None, help='Title for the plot.') parser.add_argument('-log','--log_scale', action='store_true', help='Use log scale for the data.') - parser.add_argument('--diff', action='store_true', help='Plot only the difference with the reference field.') - parser.add_argument('--ref_label', type=str, default='Reference', help='Label for the reference field.') - parser.add_argument('--cmap_diff', type=str, default='PuOr', help='Colormap to be used for the difference plot.') + + # register_arguments_cosmo(parser) args = parser.parse_args() from pysbmy.field import read_field + # from pysbmy.cosmology import d_plus - ref_label = args.ref_label ref_field = read_field(args.directory+args.reference) if args.reference is not None else None - fields = [] - for k,f in enumerate(args.filenames): - if args.reference is not None and f == args.reference: - fields.append(ref_field) # Simply copy the reference field instead of reading it again - if args.labels is not None: - ref_label = args.labels[k] # Use the label of the field as the reference label - else: - fields.append(read_field(args.directory+f)) + fields = [read_field(args.directory+f) for f in args.filenames] if args.index is None: index = fields[0].N0//2 @@ -196,6 +157,10 @@ if __name__ == "__main__": case 0 | 'x': reference = ref_field.data[index,:,:] if ref_field is not None else None fields = [f.data[index,:,:] for f in fields] + # reference = ref_field.data[index,:,:]/d_plus(1e-3,ref_field.time,parse_arguments_cosmo(args)) + # fields = [f.data[index,:,:]/d_plus(1e-3,f.time,parse_arguments_cosmo(args)) for f in fields] + # reference = ref_field.data[index,:,:]/d_plus(1e-3,0.05,parse_arguments_cosmo(args)) + # fields = [f.data[index,:,:]/d_plus(1e-3,time,parse_arguments_cosmo(args)) for f,time in zip(fields,[0.05, 1.0])] case 1 | 'y': reference = ref_field.data[:,index,:] if ref_field is not None else None fields = [f.data[:,index,:] for f in fields] @@ -210,10 +175,8 @@ if __name__ == "__main__": fields = [np.log10(2.+f) for f in fields] - if args.diff: - fig, axes = plot_imshow_diff(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap_diff, L=L, ref_label=ref_label) - else: - fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L, ref_label=ref_label, cmap_diff=args.cmap_diff) + + fig, axes = plot_imshow_with_reference(fields,reference,args.labels, vmin=args.vmin, vmax=args.vmax,cmap=args.cmap, L=L) fig.suptitle(args.title) if args.output is not None: