diff --git a/map2map/utils/figures.py b/map2map/utils/figures.py index 2a2f08c..03d8b7e 100644 --- a/map2map/utils/figures.py +++ b/map2map/utils/figures.py @@ -1,5 +1,4 @@ from math import log2, log10, ceil -import warnings import torch import numpy as np import matplotlib @@ -46,15 +45,14 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): ) for f, (field, cmap_col, norm_col) in enumerate(zip(fields, cmap, norm)): - all_non_neg = (field >= 0).all() - all_non_pos = (field <= 0).all() + all_non_neg = np.all(field >= 0) + all_non_pos = np.all(field <= 0) if cmap_col is None: if all_non_neg: - cmap_col = 'viridis' + cmap_col = 'inferno' elif all_non_pos: - warnings.warn('no implementation for all non-positive values') - cmap_col = None + cmap_col = 'inferno_r' else: cmap_col = 'RdBu_r' @@ -68,8 +66,12 @@ def plt_slices(*fields, size=64, title=None, cmap=None, norm=None): else: norm_col = LogNorm(vmin=quantize(l2), vmax=quantize(h2)) elif all_non_pos: - warnings.warn('no implementation for all non-positive values yet') - norm_col = None + if l1 < 0.1 * l2: + norm_col = Normalize(vmin=-quantize(-l2), vmax=0) + else: + norm_col = SymLogNorm(linthresh=quantize(-h2), + vmin=-quantize(-l2), + vmax=-quantize(-h2)) else: vlim = quantize(max(-l2, h2)) if w1 > 0.1 * w2 or l1 * h1 >= 0: