diff --git a/csiborgtools/read/halo_cat.py b/csiborgtools/read/halo_cat.py index 9674636..3bef45b 100644 --- a/csiborgtools/read/halo_cat.py +++ b/csiborgtools/read/halo_cat.py @@ -38,6 +38,8 @@ from .utils import add_columns, cols_to_structured from ..utils import fprint from .readsim import make_halomap_dict, load_halo_particles +from collections import OrderedDict + ############################################################################### # Base catalogue # @@ -61,7 +63,8 @@ class BaseCatalogue(ABC): _observer_velocity = None _mass_key = None - _cache = {} + _cache = OrderedDict() + _cache_maxsize = None _catalogue_length = None _is_closed = None @@ -77,7 +80,7 @@ class BaseCatalogue(ABC): ] def __init__(self, simname, nsim, nsnap, halo_finder, catalogue_name, - paths, mass_key): + paths, mass_key, cache_maxsize=64): self.simname = simname self.nsim = nsim self.nsnap = nsnap @@ -89,6 +92,7 @@ class BaseCatalogue(ABC): self._data = File(fname, "r") self._is_closed = False + self.cache_maxsize = cache_maxsize self.catalogue_name = catalogue_name @property @@ -168,6 +172,21 @@ class BaseCatalogue(ABC): raise RuntimeError("`data` is not set!") return self._data + @property + def cache_maxsize(self): + """Maximum size of the cache.""" + if self._cache_maxsize is None: + raise RuntimeError("`cache_maxsize` is not set!") + return self._cache_maxsize + + @cache_maxsize.setter + def cache_maxsize(self, cache_maxsize): + assert isinstance(cache_maxsize, int) + self._cache_maxsize = cache_maxsize + + def cache_length(self): + return len(self._cache) + @property def observer_location(self): """Observer location.""" @@ -440,6 +459,7 @@ class BaseCatalogue(ABC): if "snapshot" in key: return self.data[key] + # For non-snapshot keys, we cache the results. try: return self._cache[key] except KeyError: @@ -466,23 +486,24 @@ class BaseCatalogue(ABC): out = numpy.linalg.norm(out - self.observer_location, axis=1) elif key == "angular_momentum": out = numpy.vstack([self["Lx"], self["Ly"], self["Lz"]]).T - elif key in self.data[self.catalogue_name].keys(): - out = self.data[f"{self.catalogue_name}/{key}"][:] elif key == "particle_offset": out = make_halomap_dict(self["snapshot_final/halo_map"][:]) elif key == "npart": halomap = self["particle_offset"] out = numpy.zeros(len(halomap), dtype=numpy.int32) for i, hid in enumerate(self["index"]): - if i == 0: + if hid == 0: continue start, end = halomap[hid] out[i] = end - start + elif key in self.data[self.catalogue_name].keys(): + out = self.data[f"{self.catalogue_name}/{key}"][:] else: raise KeyError(f"Key '{key}' is not available.") - # TODO enforce a maximum size of the dictionary - # ALSO DO THE MASKING HERE? AND IF A NEW MASK THEN RESET CACHE? + # TODO: Enfore the masking somewhere here? + if self.cache_length() > self.cache_maxsize: + self._cache.popitem(last=False) self._cache[key] = out return out @@ -496,7 +517,12 @@ class BaseCatalogue(ABC): if not self._is_closed: self.data.close() self._is_closed = True - self._cache = {} + self._cache.clear() + collect() + + def clear_cache(self): + """Clear the cache dictionary.""" + self._cache.clear() collect() def __len__(self): @@ -533,10 +559,11 @@ class CSiBORGCatalogue(BaseCatalogue): Observer's velocity in :math:`\mathrm{km} / \mathrm{s}`. """ def __init__(self, nsim, paths, catalogue_name, halo_finder, mass_key=None, - observer_velocity=None): + observer_velocity=None, cache_maxsize=64): super().__init__("csiborg", nsim, max(paths.get_snapshots(nsim, "csiborg")), - halo_finder, catalogue_name, paths, mass_key) + halo_finder, catalogue_name, paths, mass_key, + cache_maxsize) self.box = CSiBORGBox(self.nsnap, self.nsim, self.paths) self.observer_location = [338.85, 338.85, 338.85] # Mpc / h self.observer_velocity = observer_velocity