Add ordered caching

rstiskalek 2023-10-20 10:28:15 +01:00
parent afb5ace871
commit b756a72251

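The change below replaces the plain dict cache with a bounded, insertion-ordered one: entries live in an OrderedDict, and once the cache grows past `cache_maxsize`, the oldest entry is evicted with `popitem(last=False)`. A minimal standalone sketch of that pattern (the `BoundedCache` name and `get` helper are illustrative, not part of the repository):

from collections import OrderedDict

class BoundedCache:
    """Insertion-ordered cache that evicts its oldest entry beyond `maxsize`."""

    def __init__(self, maxsize=64):
        self._cache = OrderedDict()
        self._maxsize = maxsize

    def get(self, key, compute):
        # Return the cached value for `key`, computing and storing it on a miss.
        try:
            return self._cache[key]
        except KeyError:
            out = compute(key)
            if len(self._cache) >= self._maxsize:
                # `last=False` pops the oldest (first-inserted) item, FIFO-style.
                self._cache.popitem(last=False)
            self._cache[key] = out
            return out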

@@ -38,6 +38,8 @@ from .utils import add_columns, cols_to_structured
 from ..utils import fprint
 from .readsim import make_halomap_dict, load_halo_particles
+from collections import OrderedDict
+
 
 ###############################################################################
 #                             Base catalogue                                  #
 ###############################################################################
@@ -61,7 +63,8 @@ class BaseCatalogue(ABC):
     _observer_velocity = None
     _mass_key = None
 
-    _cache = {}
+    _cache = OrderedDict()
+    _cache_maxsize = None
     _catalogue_length = None
     _is_closed = None
 
@ -77,7 +80,7 @@ class BaseCatalogue(ABC):
]
def __init__(self, simname, nsim, nsnap, halo_finder, catalogue_name,
paths, mass_key):
paths, mass_key, cache_maxsize=64):
self.simname = simname
self.nsim = nsim
self.nsnap = nsnap
@@ -89,6 +92,7 @@ class BaseCatalogue(ABC):
         self._data = File(fname, "r")
         self._is_closed = False
 
+        self.cache_maxsize = cache_maxsize
         self.catalogue_name = catalogue_name
 
     @property
@@ -168,6 +172,21 @@ class BaseCatalogue(ABC):
             raise RuntimeError("`data` is not set!")
         return self._data
 
+    @property
+    def cache_maxsize(self):
+        """Maximum size of the cache."""
+        if self._cache_maxsize is None:
+            raise RuntimeError("`cache_maxsize` is not set!")
+        return self._cache_maxsize
+
+    @cache_maxsize.setter
+    def cache_maxsize(self, cache_maxsize):
+        assert isinstance(cache_maxsize, int)
+        self._cache_maxsize = cache_maxsize
+
+    def cache_length(self):
+        """Length of the cache dictionary."""
+        return len(self._cache)
 
     @property
     def observer_location(self):
         """Observer location."""
@@ -440,6 +459,7 @@
         if "snapshot" in key:
             return self.data[key]
 
+        # For non-snapshot keys, we cache the results.
         try:
             return self._cache[key]
         except KeyError:
@@ -466,23 +486,24 @@
             out = numpy.linalg.norm(out - self.observer_location, axis=1)
         elif key == "angular_momentum":
             out = numpy.vstack([self["Lx"], self["Ly"], self["Lz"]]).T
-        elif key in self.data[self.catalogue_name].keys():
-            out = self.data[f"{self.catalogue_name}/{key}"][:]
         elif key == "particle_offset":
             out = make_halomap_dict(self["snapshot_final/halo_map"][:])
         elif key == "npart":
             halomap = self["particle_offset"]
             out = numpy.zeros(len(halomap), dtype=numpy.int32)
             for i, hid in enumerate(self["index"]):
-                if i == 0:
+                if hid == 0:
                     continue
                 start, end = halomap[hid]
                 out[i] = end - start
+        elif key in self.data[self.catalogue_name].keys():
+            out = self.data[f"{self.catalogue_name}/{key}"][:]
         else:
             raise KeyError(f"Key '{key}' is not available.")
 
-        # TODO enforce a maximum size of the dictionary
-        # ALSO DO THE MASKING HERE? AND IF A NEW MASK THEN RESET CACHE?
+        # TODO: Enforce the masking somewhere here?
+        if self.cache_length() > self.cache_maxsize:
+            self._cache.popitem(last=False)
         self._cache[key] = out
 
         return out
@@ -496,7 +517,12 @@
         if not self._is_closed:
             self.data.close()
             self._is_closed = True
-        self._cache = {}
+        self._cache.clear()
         collect()
 
+    def clear_cache(self):
+        """Clear the cache dictionary."""
+        self._cache.clear()
+        collect()
+
     def __len__(self):
@@ -533,10 +559,11 @@ class CSiBORGCatalogue(BaseCatalogue):
         Observer's velocity in :math:`\mathrm{km} / \mathrm{s}`.
     """
     def __init__(self, nsim, paths, catalogue_name, halo_finder, mass_key=None,
-                 observer_velocity=None):
+                 observer_velocity=None, cache_maxsize=64):
         super().__init__("csiborg", nsim,
                          max(paths.get_snapshots(nsim, "csiborg")),
-                         halo_finder, catalogue_name, paths, mass_key)
+                         halo_finder, catalogue_name, paths, mass_key,
+                         cache_maxsize)
         self.box = CSiBORGBox(self.nsnap, self.nsim, self.paths)
         self.observer_location = [338.85, 338.85, 338.85]  # Mpc / h
         self.observer_velocity = observer_velocity
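
A hypothetical usage sketch of the new keyword (the `paths` object, simulation index, and catalogue/finder names are placeholders):

# `paths`, `nsim`, and the catalogue/finder names below are placeholders;
# only the cache-related calls correspond to this commit.
cat = CSiBORGCatalogue(nsim=7444, paths=paths, catalogue_name="fof",
                       halo_finder="FOF", cache_maxsize=32)
lm = cat["angular_momentum"]   # computed on first access, then cached
lm = cat["angular_momentum"]   # served from the OrderedDict cache
cat.clear_cache()              # empties the cache and runs garbage collection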