From 31a90ea0e415c6e2c3d05abea2c82635ae7b7f5e Mon Sep 17 00:00:00 2001 From: rstiskalek Date: Fri, 20 Oct 2023 22:12:38 +0100 Subject: [PATCH] Add catalogue masking --- csiborgtools/read/halo_cat.py | 123 ++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 50 deletions(-) diff --git a/csiborgtools/read/halo_cat.py b/csiborgtools/read/halo_cat.py index 0929773..0409033 100644 --- a/csiborgtools/read/halo_cat.py +++ b/csiborgtools/read/halo_cat.py @@ -68,6 +68,8 @@ class BaseCatalogue(ABC): _cache_maxsize = None _catalogue_length = None _is_closed = None + _load_filtered = False + _filter_mask = None _derived_properties = ["cartesian_pos", "spherical_pos", @@ -81,12 +83,15 @@ class BaseCatalogue(ABC): ] def __init__(self, simname, nsim, nsnap, halo_finder, catalogue_name, - paths, mass_key, cache_maxsize=64): + paths, mass_key, bounds, observer_location, observer_velocity, + cache_maxsize=64): self.simname = simname self.nsim = nsim self.nsnap = nsnap self.paths = paths self.mass_key = mass_key + self.observer_location = observer_location + self.observer_velocity = observer_velocity fname = self.paths.processed_output(nsim, simname, halo_finder) fprint(f"opening `{fname}`.") @@ -96,6 +101,9 @@ class BaseCatalogue(ABC): self.cache_maxsize = cache_maxsize self.catalogue_name = catalogue_name + if bounds is not None: + self._make_mask(bounds) + @property def simname(self): """Simulation name.""" @@ -185,7 +193,12 @@ class BaseCatalogue(ABC): assert isinstance(cache_maxsize, int) self._cache_maxsize = cache_maxsize + def cache_keys(self): + """Keys of the cache dictionary.""" + return list(self._cache.keys()) + def cache_length(self): + """Length of the cache dictionary.""" return len(self._cache) @property @@ -378,39 +391,32 @@ class BaseCatalogue(ABC): return dist, indxs -# def filter_data(self, data, bounds): -# """ -# Filters data based on specified bounds for each key. -# -# Parameters -# ---------- -# data : structured array -# The data to be filtered. -# bounds : dict -# A dictionary with keys corresponding to data columns or `dist` and -# values as a tuple of `(xmin, xmax)`. If `xmin` or `xmax` is `None`, -# it defaults to negative infinity and positive infinity, -# respectively. -# -# Returns -# ------- -# structured array -# """ -# for key, (xmin, xmax) in bounds.items(): -# if key == "dist": -# pos = numpy.vstack([data[p] - self.observer_location[i] -# for i, p in enumerate("xyz")]).T -# values_to_filter = numpy.linalg.norm(pos, axis=1) -# else: -# values_to_filter = data[key] -# -# min_bound = xmin if xmin is not None else -numpy.inf -# max_bound = xmax if xmax is not None else numpy.inf -# -# data = data[(values_to_filter > min_bound) -# & (values_to_filter <= max_bound)] -# -# return data + def _make_mask(self, bounds): + """ + Make an internal mask for the catalogue data. + + Parameters + ---------- + bounds : dict + A dictionary with keys corresponding to data columns or `dist` and + values as a tuple of `(xmin, xmax)`. If `xmin` or `xmax` is `None`, + it defaults to negative infinity and positive infinity, + respectively. + """ + self._load_filtered = False + + mask = numpy.ones(len(self), dtype=bool) + + for key, (xmin, xmax) in bounds.items(): + values_to_filter = self[key] + + if xmin is not None: + mask &= (values_to_filter > xmin) + if xmax is not None: + mask &= (values_to_filter <= xmax) + + self._filter_mask = mask + self._load_filtered = True def keys(self): """Catalogue keys.""" @@ -433,36 +439,49 @@ class BaseCatalogue(ABC): return keys def __getitem__(self, key): + # We do not cache the snapshot keys. if "snapshot" in key: return self.data[key] - # For non-snapshot keys, we cache the results. + # For internal calls we don't want to load the filtered data and use + # the __ prefixed keys. The internal calls are not being cached. + if key.startswith("__"): + is_internal = True + key = key.lstrip("__") + else: + is_internal = False + try: - return self._cache[key] + out = self._cache[key] + if self._load_filtered and not is_internal: + return out[self._filter_mask] + else: + return out except KeyError: if key == "cartesian_pos": - out = numpy.vstack([self["x"], self["y"], self["z"]]).T + out = numpy.vstack([self["__x"], self["__y"], self["__z"]]).T elif key == "spherical_pos": out = cartesian_to_radec( - self["cartesian_pos"] - self.observer_location) + self["__cartesian_pos"] - self.observer_location) elif key == "dist": out = numpy.linalg.norm( - self["cartesian_pos"] - self.observer_location, axis=1) + self["__cartesian_pos"] - self.observer_location, axis=1) elif key == "cartesian_vel": - out = numpy.vstack([self["vx"], self["vy"], self["vz"]]) + out = numpy.vstack([self["__vx"], self["__vy"], self["__vz"]]) elif key == "cartesian_redshift_pos": out = real2redshift( - self["cartesian_pos"], self["cartesian_vel"], + self["__cartesian_pos"], self["__cartesian_vel"], self.observer_location, self.observer_velocity, self.box, make_copy=False) elif key == "spherical_redshift_pos": out = cartesian_to_radec( - self["cartesian_redshift_pos"] - self.observer_location) + self["__cartesian_redshift_pos"] - self.observer_location) elif key == "redshift_dist": - out = self["cartesian_redshift_pos"] + out = self["__cartesian_redshift_pos"] out = numpy.linalg.norm(out - self.observer_location, axis=1) elif key == "angular_momentum": - out = numpy.vstack([self["Lx"], self["Ly"], self["Lz"]]).T + out = numpy.vstack( + [self["__Lx"], self["__Ly"], self["__Lz"]]).T elif key == "particle_offset": out = make_halomap_dict(self["snapshot_final/halo_map"][:]) elif key == "npart": @@ -478,11 +497,16 @@ class BaseCatalogue(ABC): else: raise KeyError(f"Key '{key}' is not available.") - # TODO: Enfore the masking somewhere here? + if not is_internal: + self._cache[key] = out + if self.cache_length() > self.cache_maxsize: self._cache.popitem(last=False) - self._cache[key] = out - return out + + if self._load_filtered and not is_internal: + return out[self._filter_mask] + else: + return out @property def is_closed(self): @@ -536,14 +560,13 @@ class CSiBORGCatalogue(BaseCatalogue): Observer's velocity in :math:`\mathrm{km} / \mathrm{s}`. """ def __init__(self, nsim, paths, catalogue_name, halo_finder, mass_key=None, - observer_velocity=None, cache_maxsize=64): + bounds=None, observer_velocity=None, cache_maxsize=64): super().__init__("csiborg", nsim, max(paths.get_snapshots(nsim, "csiborg")), halo_finder, catalogue_name, paths, mass_key, + bounds, [338.85, 338.85, 338.85], observer_velocity, cache_maxsize) self.box = CSiBORGBox(self.nsnap, self.nsim, self.paths) - self.observer_location = [338.85, 338.85, 338.85] # Mpc / h - self.observer_velocity = observer_velocity ############################################################################### # Quijote halo catalogue #