From 3e81e6f852f6d4bd084da80f2ecd12dc5457022e Mon Sep 17 00:00:00 2001 From: rstiskalek Date: Sat, 21 Oct 2023 11:39:57 +0100 Subject: [PATCH] Change if-else statements --- csiborgtools/read/halo_cat.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/csiborgtools/read/halo_cat.py b/csiborgtools/read/halo_cat.py index 0409033..7602bdd 100644 --- a/csiborgtools/read/halo_cat.py +++ b/csiborgtools/read/halo_cat.py @@ -408,13 +408,14 @@ class BaseCatalogue(ABC): mask = numpy.ones(len(self), dtype=bool) for key, (xmin, xmax) in bounds.items(): - values_to_filter = self[key] + values_to_filter = self[f"__{key}"] if xmin is not None: mask &= (values_to_filter > xmin) if xmax is not None: mask &= (values_to_filter <= xmax) + self.clear_cache() self._filter_mask = mask self._load_filtered = True @@ -441,7 +442,10 @@ class BaseCatalogue(ABC): def __getitem__(self, key): # We do not cache the snapshot keys. if "snapshot" in key: - return self.data[key] + if key in self.data: + return self.data[key] + else: + raise KeyError(f"Key '{key}' is not available.") # For internal calls we don't want to load the filtered data and use # the __ prefixed keys. The internal calls are not being cached. @@ -451,13 +455,13 @@ class BaseCatalogue(ABC): else: is_internal = False - try: + if key in self.cache_keys(): out = self._cache[key] if self._load_filtered and not is_internal: return out[self._filter_mask] else: return out - except KeyError: + else: if key == "cartesian_pos": out = numpy.vstack([self["__x"], self["__y"], self["__z"]]).T elif key == "spherical_pos": @@ -487,7 +491,7 @@ class BaseCatalogue(ABC): elif key == "npart": halomap = self["particle_offset"] out = numpy.zeros(len(halomap), dtype=numpy.int32) - for i, hid in enumerate(self["index"]): + for i, hid in enumerate(self["__index"]): if hid == 0: continue start, end = halomap[hid] @@ -528,7 +532,7 @@ class BaseCatalogue(ABC): def __len__(self): if self._catalogue_length is None: - self._catalogue_length = len(self["index"]) + self._catalogue_length = len(self["__index"]) return self._catalogue_length