mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-05-20 17:41:13 +00:00
Add ordered caching
This commit is contained in:
parent
afb5ace871
commit
b756a72251
1 changed files with 37 additions and 10 deletions
|
@ -38,6 +38,8 @@ from .utils import add_columns, cols_to_structured
|
|||
from ..utils import fprint
|
||||
from .readsim import make_halomap_dict, load_halo_particles
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Base catalogue #
|
||||
|
@ -61,7 +63,8 @@ class BaseCatalogue(ABC):
|
|||
_observer_velocity = None
|
||||
_mass_key = None
|
||||
|
||||
_cache = {}
|
||||
_cache = OrderedDict()
|
||||
_cache_maxsize = None
|
||||
_catalogue_length = None
|
||||
_is_closed = None
|
||||
|
||||
|
@ -77,7 +80,7 @@ class BaseCatalogue(ABC):
|
|||
]
|
||||
|
||||
def __init__(self, simname, nsim, nsnap, halo_finder, catalogue_name,
|
||||
paths, mass_key):
|
||||
paths, mass_key, cache_maxsize=64):
|
||||
self.simname = simname
|
||||
self.nsim = nsim
|
||||
self.nsnap = nsnap
|
||||
|
@ -89,6 +92,7 @@ class BaseCatalogue(ABC):
|
|||
self._data = File(fname, "r")
|
||||
self._is_closed = False
|
||||
|
||||
self.cache_maxsize = cache_maxsize
|
||||
self.catalogue_name = catalogue_name
|
||||
|
||||
@property
|
||||
|
@ -168,6 +172,21 @@ class BaseCatalogue(ABC):
|
|||
raise RuntimeError("`data` is not set!")
|
||||
return self._data
|
||||
|
||||
@property
|
||||
def cache_maxsize(self):
|
||||
"""Maximum size of the cache."""
|
||||
if self._cache_maxsize is None:
|
||||
raise RuntimeError("`cache_maxsize` is not set!")
|
||||
return self._cache_maxsize
|
||||
|
||||
@cache_maxsize.setter
|
||||
def cache_maxsize(self, cache_maxsize):
|
||||
assert isinstance(cache_maxsize, int)
|
||||
self._cache_maxsize = cache_maxsize
|
||||
|
||||
def cache_length(self):
|
||||
return len(self._cache)
|
||||
|
||||
@property
|
||||
def observer_location(self):
|
||||
"""Observer location."""
|
||||
|
@ -440,6 +459,7 @@ class BaseCatalogue(ABC):
|
|||
if "snapshot" in key:
|
||||
return self.data[key]
|
||||
|
||||
# For non-snapshot keys, we cache the results.
|
||||
try:
|
||||
return self._cache[key]
|
||||
except KeyError:
|
||||
|
@ -466,23 +486,24 @@ class BaseCatalogue(ABC):
|
|||
out = numpy.linalg.norm(out - self.observer_location, axis=1)
|
||||
elif key == "angular_momentum":
|
||||
out = numpy.vstack([self["Lx"], self["Ly"], self["Lz"]]).T
|
||||
elif key in self.data[self.catalogue_name].keys():
|
||||
out = self.data[f"{self.catalogue_name}/{key}"][:]
|
||||
elif key == "particle_offset":
|
||||
out = make_halomap_dict(self["snapshot_final/halo_map"][:])
|
||||
elif key == "npart":
|
||||
halomap = self["particle_offset"]
|
||||
out = numpy.zeros(len(halomap), dtype=numpy.int32)
|
||||
for i, hid in enumerate(self["index"]):
|
||||
if i == 0:
|
||||
if hid == 0:
|
||||
continue
|
||||
start, end = halomap[hid]
|
||||
out[i] = end - start
|
||||
elif key in self.data[self.catalogue_name].keys():
|
||||
out = self.data[f"{self.catalogue_name}/{key}"][:]
|
||||
else:
|
||||
raise KeyError(f"Key '{key}' is not available.")
|
||||
|
||||
# TODO enforce a maximum size of the dictionary
|
||||
# ALSO DO THE MASKING HERE? AND IF A NEW MASK THEN RESET CACHE?
|
||||
# TODO: Enfore the masking somewhere here?
|
||||
if self.cache_length() > self.cache_maxsize:
|
||||
self._cache.popitem(last=False)
|
||||
self._cache[key] = out
|
||||
return out
|
||||
|
||||
|
@ -496,7 +517,12 @@ class BaseCatalogue(ABC):
|
|||
if not self._is_closed:
|
||||
self.data.close()
|
||||
self._is_closed = True
|
||||
self._cache = {}
|
||||
self._cache.clear()
|
||||
collect()
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear the cache dictionary."""
|
||||
self._cache.clear()
|
||||
collect()
|
||||
|
||||
def __len__(self):
|
||||
|
@ -533,10 +559,11 @@ class CSiBORGCatalogue(BaseCatalogue):
|
|||
Observer's velocity in :math:`\mathrm{km} / \mathrm{s}`.
|
||||
"""
|
||||
def __init__(self, nsim, paths, catalogue_name, halo_finder, mass_key=None,
|
||||
observer_velocity=None):
|
||||
observer_velocity=None, cache_maxsize=64):
|
||||
super().__init__("csiborg", nsim,
|
||||
max(paths.get_snapshots(nsim, "csiborg")),
|
||||
halo_finder, catalogue_name, paths, mass_key)
|
||||
halo_finder, catalogue_name, paths, mass_key,
|
||||
cache_maxsize)
|
||||
self.box = CSiBORGBox(self.nsnap, self.nsim, self.paths)
|
||||
self.observer_location = [338.85, 338.85, 338.85] # Mpc / h
|
||||
self.observer_velocity = observer_velocity
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue