Change if-else statements

This commit is contained in:
rstiskalek 2023-10-21 11:39:57 +01:00
parent 31a90ea0e4
commit 3e81e6f852

View file

@ -408,13 +408,14 @@ class BaseCatalogue(ABC):
mask = numpy.ones(len(self), dtype=bool) mask = numpy.ones(len(self), dtype=bool)
for key, (xmin, xmax) in bounds.items(): for key, (xmin, xmax) in bounds.items():
values_to_filter = self[key] values_to_filter = self[f"__{key}"]
if xmin is not None: if xmin is not None:
mask &= (values_to_filter > xmin) mask &= (values_to_filter > xmin)
if xmax is not None: if xmax is not None:
mask &= (values_to_filter <= xmax) mask &= (values_to_filter <= xmax)
self.clear_cache()
self._filter_mask = mask self._filter_mask = mask
self._load_filtered = True self._load_filtered = True
@ -441,7 +442,10 @@ class BaseCatalogue(ABC):
def __getitem__(self, key): def __getitem__(self, key):
# We do not cache the snapshot keys. # We do not cache the snapshot keys.
if "snapshot" in key: if "snapshot" in key:
return self.data[key] if key in self.data:
return self.data[key]
else:
raise KeyError(f"Key '{key}' is not available.")
# For internal calls we don't want to load the filtered data and use # For internal calls we don't want to load the filtered data and use
# the __ prefixed keys. The internal calls are not being cached. # the __ prefixed keys. The internal calls are not being cached.
@ -451,13 +455,13 @@ class BaseCatalogue(ABC):
else: else:
is_internal = False is_internal = False
try: if key in self.cache_keys():
out = self._cache[key] out = self._cache[key]
if self._load_filtered and not is_internal: if self._load_filtered and not is_internal:
return out[self._filter_mask] return out[self._filter_mask]
else: else:
return out return out
except KeyError: else:
if key == "cartesian_pos": if key == "cartesian_pos":
out = numpy.vstack([self["__x"], self["__y"], self["__z"]]).T out = numpy.vstack([self["__x"], self["__y"], self["__z"]]).T
elif key == "spherical_pos": elif key == "spherical_pos":
@ -487,7 +491,7 @@ class BaseCatalogue(ABC):
elif key == "npart": elif key == "npart":
halomap = self["particle_offset"] halomap = self["particle_offset"]
out = numpy.zeros(len(halomap), dtype=numpy.int32) out = numpy.zeros(len(halomap), dtype=numpy.int32)
for i, hid in enumerate(self["index"]): for i, hid in enumerate(self["__index"]):
if hid == 0: if hid == 0:
continue continue
start, end = halomap[hid] start, end = halomap[hid]
@ -528,7 +532,7 @@ class BaseCatalogue(ABC):
def __len__(self): def __len__(self):
if self._catalogue_length is None: if self._catalogue_length is None:
self._catalogue_length = len(self["index"]) self._catalogue_length = len(self["__index"])
return self._catalogue_length return self._catalogue_length