mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-23 03:08:01 +00:00
133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
|
# Copyright (C) 2023 Richard Stiskalek
|
||
|
# This program is free software; you can redistribute it and/or modify it
|
||
|
# under the terms of the GNU General Public License as published by the
|
||
|
# Free Software Foundation; either version 3 of the License, or (at your
|
||
|
# option) any later version.
|
||
|
#
|
||
|
# This program is distributed in the hope that it will be useful, but
|
||
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
|
||
|
# Public License for more details.
|
||
|
#
|
||
|
# You should have received a copy of the GNU General Public License along
|
||
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||
|
"""
|
||
|
Various utility functions.
|
||
|
"""
|
||
|
|
||
|
import numpy
|
||
|
from scipy.special import erf
|
||
|
|
||
|
# Module-level plotting configuration. The meanings below are inferred from
# the names only -- NOTE(review): confirm against the plotting scripts.
mplstyle: list = ["science"]  # presumably matplotlib style sheet name(s)
fout: str = "../plots/"  # presumably the output directory for figures
dpi: int = 600  # presumably the resolution used when saving figures
|
||
|
|
||
|
|
||
|
def latex_float(*floats, n=2):
    """
    Convert one or more floats to LaTeX string(s). Adapted from [1].

    Numbers that the ``g`` format renders in exponent notation are written as
    ``<mantissa> \\times 10^{<exponent>}``; all others are returned unchanged.

    Parameters
    ----------
    *floats : float
        The float(s) to be converted, passed as separate positional arguments.
    n : int, optional
        Number of significant figures (precision of the ``g`` format).

    Returns
    -------
    latex_floats : str or list of str
        A single string if exactly one float was passed, otherwise a list of
        strings in the order the floats were given.

    References
    ----------
    [1] https://stackoverflow.com/questions/13490292/format-number-using-latex-notation-in-python # noqa
    """
    def _to_latex(value):
        text = f"{value:.{n}g}"
        if "e" not in text:
            return text
        mantissa, power = text.split("e")
        # int() drops the leading "+" and any zero padding of the exponent.
        return rf"{mantissa} \times 10^{{{int(power)}}}"

    converted = [_to_latex(value) for value in floats]
    return converted[0] if len(converted) == 1 else converted
|
||
|
|
||
|
|
||
|
def nan_weighted_average(arr, weights=None, axis=None):
    """
    Weighted average of `arr` along `axis`, ignoring NaN entries.

    NaN entries of `arr` are excluded by giving them zero weight (and a value
    of zero) before summing.

    Parameters
    ----------
    arr : array-like
        Values to average; may contain NaNs.
    weights : array-like, optional
        Weights, broadcastable against `arr`. Defaults to uniform weights.
    axis : int or tuple of int, optional
        Axis (or axes) along which to average. By default all entries are
        averaged.

    Returns
    -------
    average : scalar or numpy.ndarray
        The weighted average. A slice with no valid entries yields 0 / 0.
    """
    if weights is None:
        weights = numpy.ones_like(arr)

    nan_mask = numpy.isnan(arr)
    clean_arr = numpy.where(nan_mask, 0, arr)
    clean_weights = numpy.where(nan_mask, 0, weights)

    numerator = (clean_arr * clean_weights).sum(axis=axis)
    denominator = clean_weights.sum(axis=axis)
    return numerator / denominator
|
||
|
|
||
|
|
||
|
def nan_weighted_std(arr, weights=None, axis=None, ddof=0):
    """
    Weighted standard deviation of `arr` along `axis`, ignoring NaN entries.

    NaN entries of `arr` are excluded by giving them zero weight. The variance
    is ``sum(w * (x - mean)**2) / (sum(w) - ddof)``, where ``mean`` is the
    weighted mean over the same axis.

    Parameters
    ----------
    arr : array-like
        Values; may contain NaNs.
    weights : array-like, optional
        Weights, broadcastable against `arr`. Defaults to uniform weights.
    axis : int or tuple of int, optional
        Axis (or axes) along which to compute the standard deviation. By
        default, all entries are used.
    ddof : int or float, optional
        Value subtracted from the sum of weights in the variance denominator.

    Returns
    -------
    std : scalar or numpy.ndarray
        The weighted standard deviation.
    """
    if weights is None:
        weights = numpy.ones_like(arr)

    valid_entries = ~numpy.isnan(arr)
    # Zero out NaN entries and their weights so they drop out of every sum.
    arr = numpy.where(valid_entries, arr, 0)
    weights = numpy.where(valid_entries, weights, 0)

    # keepdims=True lets the mean broadcast back against `arr` for any axis,
    # including axis=None, where numpy.expand_dims(mean, None) would raise.
    weighted_mean = (numpy.sum(arr * weights, axis=axis, keepdims=True)
                     / numpy.sum(weights, axis=axis, keepdims=True))

    variance = numpy.sum(weights * (arr - weighted_mean)**2, axis=axis)
    variance /= numpy.sum(weights, axis=axis) - ddof

    return numpy.sqrt(variance)
|
||
|
|
||
|
|
||
|
def compute_error_bars(x, y, xbins, sigma):
    """
    Median of `y` in bins of `x`, with asymmetric percentile error bars.

    Parameters
    ----------
    x, y : array-like
        Paired samples; `x` is binned and `y` is summarised in each bin.
    xbins : array-like
        Bin edges passed to ``numpy.digitize``. For increasing edges, bin ``i``
        collects points with ``xbins[i] <= x < xbins[i + 1]``; points outside
        ``[xbins[0], xbins[-1])`` are ignored.
    sigma : float
        Width of the error bars in Gaussian standard deviations: the lower and
        upper percentiles enclose the central ``erf(sigma / sqrt(2))``
        fraction of the points in each bin.

    Returns
    -------
    y_medians : numpy.ndarray of shape (len(xbins) - 1,)
        Median of `y` in each bin; NaN for empty bins.
    yerr : tuple of two numpy.ndarray
        ``(y_medians - lower, upper - y_medians)``; NaN for empty bins.
    """
    # Accept lists as well: boolean-mask indexing below needs arrays.
    x = numpy.asarray(x)
    y = numpy.asarray(y)
    bin_indices = numpy.digitize(x, xbins)

    # Two-sided percentiles enclosing a Gaussian `sigma` interval.
    lower_pct = 100 * 0.5 * (1 - erf(sigma / numpy.sqrt(2)))
    upper_pct = 100 - lower_pct

    nbins = max(len(xbins) - 1, 0)
    y_medians = numpy.full(nbins, numpy.nan)
    y_lower = numpy.full(nbins, numpy.nan)
    y_upper = numpy.full(nbins, numpy.nan)

    for i in range(nbins):
        # numpy.digitize labels the first bin 1, hence the offset.
        y_in_bin = y[bin_indices == i + 1]
        if y_in_bin.size == 0:
            # Leave empty bins as NaN without triggering numpy's
            # empty-slice RuntimeWarning from numpy.median.
            continue
        y_medians[i] = numpy.median(y_in_bin)
        y_lower[i] = numpy.percentile(y_in_bin, lower_pct)
        y_upper[i] = numpy.percentile(y_in_bin, upper_pct)

    yerr = (y_medians - y_lower, y_upper - y_medians)
    return y_medians, yerr
|
||
|
|
||
|
|
||
|
def normalize_hexbin(hb):
    """
    Rescale the per-cell values of a hexbin collection in place so they sum
    to one, then reset its colour limits to the new minimum and maximum.

    Parameters
    ----------
    hb : object
        Presumably the collection returned by matplotlib's ``Axes.hexbin``;
        it must provide ``get_array``, ``set_array`` and ``set_clim``.

    Returns
    -------
    None
    """
    counts = hb.get_array()
    fractions = counts / counts.sum()
    hb.set_array(fractions)
    lowest, highest = fractions.min(), fractions.max()
    hb.set_clim(lowest, highest)
|