mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-05-20 17:41:13 +00:00
Add catalogue masking
This commit is contained in:
parent
5090523a10
commit
31a90ea0e4
1 changed files with 73 additions and 50 deletions
|
@ -68,6 +68,8 @@ class BaseCatalogue(ABC):
|
|||
_cache_maxsize = None
|
||||
_catalogue_length = None
|
||||
_is_closed = None
|
||||
_load_filtered = False
|
||||
_filter_mask = None
|
||||
|
||||
_derived_properties = ["cartesian_pos",
|
||||
"spherical_pos",
|
||||
|
@ -81,12 +83,15 @@ class BaseCatalogue(ABC):
|
|||
]
|
||||
|
||||
def __init__(self, simname, nsim, nsnap, halo_finder, catalogue_name,
|
||||
paths, mass_key, cache_maxsize=64):
|
||||
paths, mass_key, bounds, observer_location, observer_velocity,
|
||||
cache_maxsize=64):
|
||||
self.simname = simname
|
||||
self.nsim = nsim
|
||||
self.nsnap = nsnap
|
||||
self.paths = paths
|
||||
self.mass_key = mass_key
|
||||
self.observer_location = observer_location
|
||||
self.observer_velocity = observer_velocity
|
||||
|
||||
fname = self.paths.processed_output(nsim, simname, halo_finder)
|
||||
fprint(f"opening `{fname}`.")
|
||||
|
@ -96,6 +101,9 @@ class BaseCatalogue(ABC):
|
|||
self.cache_maxsize = cache_maxsize
|
||||
self.catalogue_name = catalogue_name
|
||||
|
||||
if bounds is not None:
|
||||
self._make_mask(bounds)
|
||||
|
||||
@property
|
||||
def simname(self):
|
||||
"""Simulation name."""
|
||||
|
@ -185,7 +193,12 @@ class BaseCatalogue(ABC):
|
|||
assert isinstance(cache_maxsize, int)
|
||||
self._cache_maxsize = cache_maxsize
|
||||
|
||||
def cache_keys(self):
    """Return a list with the keys currently held in the cache dictionary."""
    stored_keys = self._cache.keys()
    return list(stored_keys)
|
||||
|
||||
def cache_length(self):
    """Return the number of entries currently held in the cache dictionary."""
    n_entries = len(self._cache)
    return n_entries
|
||||
|
||||
@property
|
||||
|
@ -378,39 +391,32 @@ class BaseCatalogue(ABC):
|
|||
|
||||
return dist, indxs
|
||||
|
||||
# def filter_data(self, data, bounds):
|
||||
# """
|
||||
# Filters data based on specified bounds for each key.
|
||||
#
|
||||
# Parameters
|
||||
# ----------
|
||||
# data : structured array
|
||||
# The data to be filtered.
|
||||
# bounds : dict
|
||||
# A dictionary with keys corresponding to data columns or `dist` and
|
||||
# values as a tuple of `(xmin, xmax)`. If `xmin` or `xmax` is `None`,
|
||||
# it defaults to negative infinity and positive infinity,
|
||||
# respectively.
|
||||
#
|
||||
# Returns
|
||||
# -------
|
||||
# structured array
|
||||
# """
|
||||
# for key, (xmin, xmax) in bounds.items():
|
||||
# if key == "dist":
|
||||
# pos = numpy.vstack([data[p] - self.observer_location[i]
|
||||
# for i, p in enumerate("xyz")]).T
|
||||
# values_to_filter = numpy.linalg.norm(pos, axis=1)
|
||||
# else:
|
||||
# values_to_filter = data[key]
|
||||
#
|
||||
# min_bound = xmin if xmin is not None else -numpy.inf
|
||||
# max_bound = xmax if xmax is not None else numpy.inf
|
||||
#
|
||||
# data = data[(values_to_filter > min_bound)
|
||||
# & (values_to_filter <= max_bound)]
|
||||
#
|
||||
# return data
|
||||
def _make_mask(self, bounds):
    """
    Build and store the internal boolean mask selecting catalogue entries.

    An entry is kept when, for every key in `bounds`, its value is strictly
    greater than the lower limit and less than or equal to the upper limit.
    The result is stored in `self._filter_mask` and `self._load_filtered`
    is set to `True`.

    Parameters
    ----------
    bounds : dict
        Maps a key retrievable via `self[key]` to a tuple `(xmin, xmax)`.
        A limit given as `None` is not applied.
    """
    # Turn filtering off first; it is switched back on only once the new
    # mask has been fully constructed.
    self._load_filtered = False

    selection = numpy.ones(len(self), dtype=bool)

    for name, limits in bounds.items():
        lower, upper = limits
        column = self[name]

        # Lower limit is exclusive, upper limit is inclusive.
        if lower is not None:
            selection &= (column > lower)
        if upper is not None:
            selection &= (column <= upper)

    self._filter_mask = selection
    self._load_filtered = True
|
||||
|
||||
def keys(self):
|
||||
"""Catalogue keys."""
|
||||
|
@ -433,36 +439,49 @@ class BaseCatalogue(ABC):
|
|||
return keys
|
||||
|
||||
def __getitem__(self, key):
|
||||
# We do not cache the snapshot keys.
|
||||
if "snapshot" in key:
|
||||
return self.data[key]
|
||||
|
||||
# For non-snapshot keys, we cache the results.
|
||||
# For internal calls we don't want to load the filtered data and use
|
||||
# the __ prefixed keys. The internal calls are not being cached.
|
||||
if key.startswith("__"):
|
||||
is_internal = True
|
||||
key = key.lstrip("__")
|
||||
else:
|
||||
is_internal = False
|
||||
|
||||
try:
|
||||
return self._cache[key]
|
||||
out = self._cache[key]
|
||||
if self._load_filtered and not is_internal:
|
||||
return out[self._filter_mask]
|
||||
else:
|
||||
return out
|
||||
except KeyError:
|
||||
if key == "cartesian_pos":
|
||||
out = numpy.vstack([self["x"], self["y"], self["z"]]).T
|
||||
out = numpy.vstack([self["__x"], self["__y"], self["__z"]]).T
|
||||
elif key == "spherical_pos":
|
||||
out = cartesian_to_radec(
|
||||
self["cartesian_pos"] - self.observer_location)
|
||||
self["__cartesian_pos"] - self.observer_location)
|
||||
elif key == "dist":
|
||||
out = numpy.linalg.norm(
|
||||
self["cartesian_pos"] - self.observer_location, axis=1)
|
||||
self["__cartesian_pos"] - self.observer_location, axis=1)
|
||||
elif key == "cartesian_vel":
|
||||
out = numpy.vstack([self["vx"], self["vy"], self["vz"]])
|
||||
out = numpy.vstack([self["__vx"], self["__vy"], self["__vz"]])
|
||||
elif key == "cartesian_redshift_pos":
|
||||
out = real2redshift(
|
||||
self["cartesian_pos"], self["cartesian_vel"],
|
||||
self["__cartesian_pos"], self["__cartesian_vel"],
|
||||
self.observer_location, self.observer_velocity, self.box,
|
||||
make_copy=False)
|
||||
elif key == "spherical_redshift_pos":
|
||||
out = cartesian_to_radec(
|
||||
self["cartesian_redshift_pos"] - self.observer_location)
|
||||
self["__cartesian_redshift_pos"] - self.observer_location)
|
||||
elif key == "redshift_dist":
|
||||
out = self["cartesian_redshift_pos"]
|
||||
out = self["__cartesian_redshift_pos"]
|
||||
out = numpy.linalg.norm(out - self.observer_location, axis=1)
|
||||
elif key == "angular_momentum":
|
||||
out = numpy.vstack([self["Lx"], self["Ly"], self["Lz"]]).T
|
||||
out = numpy.vstack(
|
||||
[self["__Lx"], self["__Ly"], self["__Lz"]]).T
|
||||
elif key == "particle_offset":
|
||||
out = make_halomap_dict(self["snapshot_final/halo_map"][:])
|
||||
elif key == "npart":
|
||||
|
@ -478,11 +497,16 @@ class BaseCatalogue(ABC):
|
|||
else:
|
||||
raise KeyError(f"Key '{key}' is not available.")
|
||||
|
||||
# TODO: Enforce the masking somewhere here?
|
||||
if not is_internal:
|
||||
self._cache[key] = out
|
||||
|
||||
if self.cache_length() > self.cache_maxsize:
|
||||
self._cache.popitem(last=False)
|
||||
self._cache[key] = out
|
||||
return out
|
||||
|
||||
if self._load_filtered and not is_internal:
|
||||
return out[self._filter_mask]
|
||||
else:
|
||||
return out
|
||||
|
||||
@property
|
||||
def is_closed(self):
|
||||
|
@ -536,14 +560,13 @@ class CSiBORGCatalogue(BaseCatalogue):
|
|||
Observer's velocity in :math:`\mathrm{km} / \mathrm{s}`.
|
||||
"""
|
||||
def __init__(self, nsim, paths, catalogue_name, halo_finder, mass_key=None,
|
||||
observer_velocity=None, cache_maxsize=64):
|
||||
bounds=None, observer_velocity=None, cache_maxsize=64):
|
||||
super().__init__("csiborg", nsim,
|
||||
max(paths.get_snapshots(nsim, "csiborg")),
|
||||
halo_finder, catalogue_name, paths, mass_key,
|
||||
bounds, [338.85, 338.85, 338.85], observer_velocity,
|
||||
cache_maxsize)
|
||||
self.box = CSiBORGBox(self.nsnap, self.nsim, self.paths)
|
||||
self.observer_location = [338.85, 338.85, 338.85] # Mpc / h
|
||||
self.observer_velocity = observer_velocity
|
||||
|
||||
###############################################################################
|
||||
# Quijote halo catalogue #
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue