mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 18:08:03 +00:00
Quijote kNN adding (#62)
* Fix small bug * Add fiducial observers * Rename 1D knn * Add new bounds system * rm whitespace * Add boudns * Add simname to paths * Add fiducial obserevrs * apply bounds only if not none * Add TODO * add simnames * update script * Fix distance bug * update yaml * Update file reading * Update gitignore * Add plots * add check if empty list * add func to obtaining cross * Update nb * Remove blank lines * update ignroes * loop over a few ics * update gitignore * add comments
This commit is contained in:
parent
7971fe2bc1
commit
255bec9710
16 changed files with 635 additions and 231 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -18,3 +18,7 @@ scripts/plot_correlation.ipynb
|
||||||
scripts/*.sh
|
scripts/*.sh
|
||||||
venv/
|
venv/
|
||||||
.trunk/*
|
.trunk/*
|
||||||
|
scripts_test/
|
||||||
|
scripts_plots/python.sh
|
||||||
|
scripts_plots/submit.sh
|
||||||
|
scripts_plots/*.out
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
from warnings import warn
|
from warnings import warn
|
||||||
|
|
||||||
from csiborgtools.clustering.knn import kNN_CDF # noqa
|
from csiborgtools.clustering.knn import kNN_1DCDF # noqa
|
||||||
from csiborgtools.clustering.utils import (BaseRVS, RVSinbox, # noqa
|
from csiborgtools.clustering.utils import (BaseRVS, RVSinbox, # noqa
|
||||||
RVSinsphere, RVSonsphere,
|
RVSinsphere, RVSonsphere,
|
||||||
normalised_marks)
|
normalised_marks)
|
||||||
|
|
|
@ -22,8 +22,10 @@ from scipy.stats import binned_statistic
|
||||||
from .utils import BaseRVS
|
from .utils import BaseRVS
|
||||||
|
|
||||||
|
|
||||||
class kNN_CDF:
|
class kNN_1DCDF:
|
||||||
"""Object to calculate the kNN-CDF statistic."""
|
"""
|
||||||
|
Object to calculate the 1-dimensional kNN-CDF statistic.
|
||||||
|
"""
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cdf_from_samples(r, rmin=None, rmax=None, neval=None,
|
def cdf_from_samples(r, rmin=None, rmax=None, neval=None,
|
||||||
dtype=numpy.float32):
|
dtype=numpy.float32):
|
||||||
|
|
|
@ -13,12 +13,13 @@
|
||||||
# with this program; if not, write to the Free Software Foundation, Inc.,
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
from .box_units import CSiBORGBox, QuijoteBox # noqa
|
from .box_units import CSiBORGBox, QuijoteBox # noqa
|
||||||
from .halo_cat import ClumpsCatalogue, HaloCatalogue, QuijoteHaloCatalogue # noqa
|
from .halo_cat import (ClumpsCatalogue, HaloCatalogue, # noqa
|
||||||
|
QuijoteHaloCatalogue, fiducial_observers)
|
||||||
from .knn_summary import kNNCDFReader # noqa
|
from .knn_summary import kNNCDFReader # noqa
|
||||||
from .obs import (SDSS, MCXCClusters, PlanckClusters, TwoMPPGalaxies, # noqa
|
from .obs import (SDSS, MCXCClusters, PlanckClusters, TwoMPPGalaxies, # noqa
|
||||||
TwoMPPGroups)
|
TwoMPPGroups)
|
||||||
from .overlap_summary import (NPairsOverlap, PairOverlap, # noqa
|
from .overlap_summary import (NPairsOverlap, PairOverlap, # noqa
|
||||||
binned_resample_mean)
|
binned_resample_mean, get_cross_sims)
|
||||||
from .paths import Paths # noqa
|
from .paths import Paths # noqa
|
||||||
from .pk_summary import PKReader # noqa
|
from .pk_summary import PKReader # noqa
|
||||||
from .readsim import (MmainReader, ParticleReader, halfwidth_mask, # noqa
|
from .readsim import (MmainReader, ParticleReader, halfwidth_mask, # noqa
|
||||||
|
|
|
@ -18,9 +18,14 @@ Simulation catalogues:
|
||||||
- Quijote: halo catalogue.
|
- Quijote: halo catalogue.
|
||||||
"""
|
"""
|
||||||
from abc import ABC, abstractproperty
|
from abc import ABC, abstractproperty
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import lru_cache
|
||||||
|
from itertools import product
|
||||||
|
from math import floor
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from readfof import FoF_catalog
|
from readfof import FoF_catalog
|
||||||
from sklearn.neighbors import NearestNeighbors
|
from sklearn.neighbors import NearestNeighbors
|
||||||
|
|
||||||
|
@ -98,6 +103,16 @@ class BaseCatalogue(ABC):
|
||||||
raise RuntimeError("Catalogue data not loaded!")
|
raise RuntimeError("Catalogue data not loaded!")
|
||||||
return self._data
|
return self._data
|
||||||
|
|
||||||
|
def apply_bounds(self, bounds):
|
||||||
|
for key, (xmin, xmax) in bounds.items():
|
||||||
|
xmin = -numpy.inf if xmin is None else xmin
|
||||||
|
xmax = numpy.inf if xmax is None else xmax
|
||||||
|
if key == "dist":
|
||||||
|
x = self.radial_distance(in_initial=False)
|
||||||
|
else:
|
||||||
|
x = self[key]
|
||||||
|
self._data = self._data[(x > xmin) & (x <= xmax)]
|
||||||
|
|
||||||
@abstractproperty
|
@abstractproperty
|
||||||
def box(self):
|
def box(self):
|
||||||
"""
|
"""
|
||||||
|
@ -175,6 +190,22 @@ class BaseCatalogue(ABC):
|
||||||
rsp = cartesian_to_radec(rsp)
|
rsp = cartesian_to_radec(rsp)
|
||||||
return rsp
|
return rsp
|
||||||
|
|
||||||
|
def radial_distance(self, in_initial=False):
|
||||||
|
r"""
|
||||||
|
Distance of haloes from the origin.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_initial : bool, optional
|
||||||
|
Whether to calculate in the initial snapshot.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
radial_distance : 1-dimensional array of shape `(nobjects,)`
|
||||||
|
"""
|
||||||
|
pos = self.position(in_initial=in_initial, cartesian=True)
|
||||||
|
return numpy.linalg.norm(pos, axis=1)
|
||||||
|
|
||||||
def angmomentum(self):
|
def angmomentum(self):
|
||||||
"""
|
"""
|
||||||
Cartesian angular momentum components of halos in the box coordinate
|
Cartesian angular momentum components of halos in the box coordinate
|
||||||
|
@ -186,9 +217,10 @@ class BaseCatalogue(ABC):
|
||||||
"""
|
"""
|
||||||
return numpy.vstack([self["L{}".format(p)] for p in ("x", "y", "z")]).T
|
return numpy.vstack([self["L{}".format(p)] for p in ("x", "y", "z")]).T
|
||||||
|
|
||||||
|
@lru_cache(maxsize=2)
|
||||||
def knn(self, in_initial):
|
def knn(self, in_initial):
|
||||||
"""
|
"""
|
||||||
kNN object fitted on all catalogue objects.
|
kNN object fitted on all catalogue objects. Caches the kNN object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
@ -202,19 +234,29 @@ class BaseCatalogue(ABC):
|
||||||
knn = NearestNeighbors()
|
knn = NearestNeighbors()
|
||||||
return knn.fit(self.position(in_initial=in_initial))
|
return knn.fit(self.position(in_initial=in_initial))
|
||||||
|
|
||||||
def radius_neigbours(self, X, radius, in_initial):
|
def nearest_neighbours(self, X, radius, in_initial, knearest=False,
|
||||||
|
return_mass=False, masss_key=None):
|
||||||
r"""
|
r"""
|
||||||
Sorted nearest neigbours within `radius` of `X` in the initial
|
Sorted nearest neigbours within `radius` of `X` in the initial or final
|
||||||
or final snapshot.
|
snapshot. However, if `knearest` is `True` then the `radius` is assumed
|
||||||
|
to be the integer number of nearest neighbours to return.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X : 2-dimensional array of shape `(n_queries, 3)`
|
X : 2-dimensional array of shape `(n_queries, 3)`
|
||||||
Cartesian query position components in :math:`\mathrm{cMpc}`.
|
Cartesian query position components in :math:`\mathrm{cMpc}`.
|
||||||
radius : float
|
radius : float or int
|
||||||
Limiting neighbour distance.
|
Limiting neighbour distance. If `knearest` is `True` then this is
|
||||||
|
the number of nearest neighbours to return.
|
||||||
in_initial : bool
|
in_initial : bool
|
||||||
Whether to define the kNN on the initial or final snapshot.
|
Whether to define the kNN on the initial or final snapshot.
|
||||||
|
knearest : bool, optional
|
||||||
|
Whether `radius` is the number of nearest neighbours to return.
|
||||||
|
return_mass : bool, optional
|
||||||
|
Whether to return the masses of the nearest neighbours.
|
||||||
|
masss_key : str, optional
|
||||||
|
Key of the mass column in the catalogue. Must be provided if
|
||||||
|
`return_mass` is `True`.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -227,8 +269,30 @@ class BaseCatalogue(ABC):
|
||||||
"""
|
"""
|
||||||
if not (X.ndim == 2 and X.shape[1] == 3):
|
if not (X.ndim == 2 and X.shape[1] == 3):
|
||||||
raise TypeError("`X` must be an array of shape `(n_samples, 3)`.")
|
raise TypeError("`X` must be an array of shape `(n_samples, 3)`.")
|
||||||
|
if knearest:
|
||||||
|
assert isinstance(radius, int)
|
||||||
|
if return_mass:
|
||||||
|
assert masss_key is not None
|
||||||
knn = self.knn(in_initial)
|
knn = self.knn(in_initial)
|
||||||
return knn.radius_neighbors(X, radius, sort_results=True)
|
|
||||||
|
if knearest:
|
||||||
|
dist, indxs = knn.kneighbors(X, radius)
|
||||||
|
else:
|
||||||
|
dist, indxs = knn.radius_neighbors(X, radius, sort_results=True)
|
||||||
|
|
||||||
|
if not return_mass:
|
||||||
|
return dist, indxs
|
||||||
|
|
||||||
|
if knearest:
|
||||||
|
mass = numpy.copy(dist)
|
||||||
|
for i in range(dist.shape[0]):
|
||||||
|
mass[i, :] = self[masss_key][indxs[i]]
|
||||||
|
else:
|
||||||
|
mass = deepcopy(dist)
|
||||||
|
for i in range(dist.size):
|
||||||
|
mass[i] = self[masss_key][indxs[i]]
|
||||||
|
|
||||||
|
return dist, indxs, mass
|
||||||
|
|
||||||
def angular_neighbours(self, X, ang_radius, in_rsp, rad_tolerance=None):
|
def angular_neighbours(self, X, ang_radius, in_rsp, rad_tolerance=None):
|
||||||
r"""
|
r"""
|
||||||
|
@ -354,13 +418,11 @@ class ClumpsCatalogue(BaseCSiBORG):
|
||||||
IC realisation index.
|
IC realisation index.
|
||||||
paths : py:class`csiborgtools.read.Paths`
|
paths : py:class`csiborgtools.read.Paths`
|
||||||
Paths object.
|
Paths object.
|
||||||
maxdist : float, optional
|
bounds : dict
|
||||||
The maximum comoving distance of a halo. By default
|
Parameter bounds to apply to the catalogue. The keys are the parameter
|
||||||
:math:`155.5 / 0.705 ~ \mathrm{Mpc}` with assumed :math:`h = 0.705`,
|
names and the items are a len-2 tuple of (min, max) values. In case of
|
||||||
which corresponds to the high-resolution region.
|
no minimum or maximum, use `None`. For radial distance from the origin
|
||||||
minmass : len-2 tuple, optional
|
use `dist`.
|
||||||
Minimum mass. The first element is the catalogue key and the second is
|
|
||||||
the value.
|
|
||||||
load_fitted : bool, optional
|
load_fitted : bool, optional
|
||||||
Whether to load fitted quantities.
|
Whether to load fitted quantities.
|
||||||
rawdata : bool, optional
|
rawdata : bool, optional
|
||||||
|
@ -368,8 +430,8 @@ class ClumpsCatalogue(BaseCSiBORG):
|
||||||
transformations.
|
transformations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, nsim, paths, maxdist=155.5 / 0.705,
|
def __init__(self, nsim, paths, bounds={"dist": (0, 155.5 / 0.705)},
|
||||||
minmass=("mass_cl", 1e12), load_fitted=True, rawdata=False):
|
load_fitted=True, rawdata=False):
|
||||||
self.nsim = nsim
|
self.nsim = nsim
|
||||||
self.paths = paths
|
self.paths = paths
|
||||||
# Read in the clumps from the final snapshot
|
# Read in the clumps from the final snapshot
|
||||||
|
@ -396,12 +458,8 @@ class ClumpsCatalogue(BaseCSiBORG):
|
||||||
"r500c", "m200c", "m500c", "r200m", "m200m",
|
"r500c", "m200c", "m500c", "r200m", "m200m",
|
||||||
"vx", "vy", "vz"]
|
"vx", "vy", "vz"]
|
||||||
self._data = self.box.convert_from_box(self._data, names)
|
self._data = self.box.convert_from_box(self._data, names)
|
||||||
if maxdist is not None:
|
if bounds is not None:
|
||||||
dist = numpy.sqrt(self._data["x"]**2 + self._data["y"]**2
|
self.apply_bounds(bounds)
|
||||||
+ self._data["z"]**2)
|
|
||||||
self._data = self._data[dist < maxdist]
|
|
||||||
if minmass is not None:
|
|
||||||
self._data = self._data[self._data[minmass[0]] > minmass[1]]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ismain(self):
|
def ismain(self):
|
||||||
|
@ -431,13 +489,11 @@ class HaloCatalogue(BaseCSiBORG):
|
||||||
IC realisation index.
|
IC realisation index.
|
||||||
paths : py:class`csiborgtools.read.Paths`
|
paths : py:class`csiborgtools.read.Paths`
|
||||||
Paths object.
|
Paths object.
|
||||||
maxdist : float, optional
|
bounds : dict
|
||||||
The maximum comoving distance of a halo. By default
|
Parameter bounds to apply to the catalogue. The keys are the parameter
|
||||||
:math:`155.5 / 0.705 ~ \mathrm{Mpc}` with assumed :math:`h = 0.705`,
|
names and the items are a len-2 tuple of (min, max) values. In case of
|
||||||
which corresponds to the high-resolution region.
|
no minimum or maximum, use `None`. For radial distance from the origin
|
||||||
minmass : len-2 tuple
|
use `dist`.
|
||||||
Minimum mass. The first element is the catalogue key and the second is
|
|
||||||
the value.
|
|
||||||
with_lagpatch : bool, optional
|
with_lagpatch : bool, optional
|
||||||
Whether to only load halos with a resolved Lagrangian patch.
|
Whether to only load halos with a resolved Lagrangian patch.
|
||||||
load_fitted : bool, optional
|
load_fitted : bool, optional
|
||||||
|
@ -450,7 +506,7 @@ class HaloCatalogue(BaseCSiBORG):
|
||||||
"""
|
"""
|
||||||
_clumps_cat = None
|
_clumps_cat = None
|
||||||
|
|
||||||
def __init__(self, nsim, paths, maxdist=155.5 / 0.705, minmass=("M", 1e12),
|
def __init__(self, nsim, paths, bounds={"dist": (0, 155.5 / 0.705)},
|
||||||
with_lagpatch=True, load_fitted=True, load_initial=True,
|
with_lagpatch=True, load_fitted=True, load_initial=True,
|
||||||
load_clumps_cat=False, rawdata=False):
|
load_clumps_cat=False, rawdata=False):
|
||||||
self.nsim = nsim
|
self.nsim = nsim
|
||||||
|
@ -498,12 +554,8 @@ class HaloCatalogue(BaseCSiBORG):
|
||||||
names = ["x0", "y0", "z0", "lagpatch"]
|
names = ["x0", "y0", "z0", "lagpatch"]
|
||||||
self._data = self.box.convert_from_box(self._data, names)
|
self._data = self.box.convert_from_box(self._data, names)
|
||||||
|
|
||||||
if maxdist is not None:
|
if bounds is not None:
|
||||||
dist = numpy.sqrt(self._data["x"]**2 + self._data["y"]**2
|
self.apply_bounds(bounds)
|
||||||
+ self._data["z"]**2)
|
|
||||||
self._data = self._data[dist < maxdist]
|
|
||||||
if minmass is not None:
|
|
||||||
self._data = self._data[self._data[minmass[0]] > minmass[1]]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def clumps_cat(self):
|
def clumps_cat(self):
|
||||||
|
@ -538,16 +590,13 @@ class QuijoteHaloCatalogue(BaseCatalogue):
|
||||||
Snapshot index.
|
Snapshot index.
|
||||||
origin : len-3 tuple, optional
|
origin : len-3 tuple, optional
|
||||||
Where to place the origin of the box. By default the centre of the box.
|
Where to place the origin of the box. By default the centre of the box.
|
||||||
In units of :math:`cMpc`.
|
In units of :math:`cMpc`. Optionally can be an integer between 0 and 8,
|
||||||
maxdist : float, optional
|
inclusive to correspond to CSiBORG boxes.
|
||||||
The maximum comoving distance of a halo in the new reference frame, in
|
bounds : dict
|
||||||
units of :math:`cMpc`.
|
Parameter bounds to apply to the catalogue. The keys are the parameter
|
||||||
minmass : len-2 tuple
|
names and the items are a len-2 tuple of (min, max) values. In case of
|
||||||
Minimum mass. The first element is the catalogue key and the second is
|
no minimum or maximum, use `None`. For radial distance from the origin
|
||||||
the value.
|
use `dist`.
|
||||||
rawdata : bool, optional
|
|
||||||
Whether to return the raw data. In this case applies no cuts and
|
|
||||||
transformations.
|
|
||||||
**kwargs : dict
|
**kwargs : dict
|
||||||
Keyword arguments for backward compatibility.
|
Keyword arguments for backward compatibility.
|
||||||
"""
|
"""
|
||||||
|
@ -555,8 +604,7 @@ class QuijoteHaloCatalogue(BaseCatalogue):
|
||||||
|
|
||||||
def __init__(self, nsim, paths, nsnap,
|
def __init__(self, nsim, paths, nsnap,
|
||||||
origin=[500 / 0.6711, 500 / 0.6711, 500 / 0.6711],
|
origin=[500 / 0.6711, 500 / 0.6711, 500 / 0.6711],
|
||||||
maxdist=None, minmass=("group_mass", 1e12), rawdata=False,
|
bounds=None, **kwargs):
|
||||||
**kwargs):
|
|
||||||
self.paths = paths
|
self.paths = paths
|
||||||
self.nsnap = nsnap
|
self.nsnap = nsnap
|
||||||
fpath = join(self.paths.quijote_dir, "halos", str(nsim))
|
fpath = join(self.paths.quijote_dir, "halos", str(nsim))
|
||||||
|
@ -569,9 +617,12 @@ class QuijoteHaloCatalogue(BaseCatalogue):
|
||||||
("group_mass", numpy.float32), ("npart", numpy.int32)]
|
("group_mass", numpy.float32), ("npart", numpy.int32)]
|
||||||
data = cols_to_structured(fof.GroupLen.size, cols)
|
data = cols_to_structured(fof.GroupLen.size, cols)
|
||||||
|
|
||||||
|
if isinstance(origin, int):
|
||||||
|
origin = fiducial_observers(1000 / 0.6711, 155.5 / 0.6711)[origin]
|
||||||
|
|
||||||
pos = fof.GroupPos / 1e3 / self.box.h
|
pos = fof.GroupPos / 1e3 / self.box.h
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
pos -= origin[i]
|
pos[:, i] -= origin[i]
|
||||||
vel = fof.GroupVel * (1 + self.redshift)
|
vel = fof.GroupVel * (1 + self.redshift)
|
||||||
for i, p in enumerate(["x", "y", "z"]):
|
for i, p in enumerate(["x", "y", "z"]):
|
||||||
data[p] = pos[:, i]
|
data[p] = pos[:, i]
|
||||||
|
@ -579,14 +630,9 @@ class QuijoteHaloCatalogue(BaseCatalogue):
|
||||||
data["group_mass"] = fof.GroupMass * 1e10 / self.box.h
|
data["group_mass"] = fof.GroupMass * 1e10 / self.box.h
|
||||||
data["npart"] = fof.GroupLen
|
data["npart"] = fof.GroupLen
|
||||||
|
|
||||||
if not rawdata:
|
|
||||||
if maxdist is not None:
|
|
||||||
pos = numpy.vstack([data["x"], data["y"], data["z"]]).T
|
|
||||||
data = data[numpy.linalg.norm(pos, axis=1) < maxdist]
|
|
||||||
if minmass is not None:
|
|
||||||
data = data[data[minmass[0]] > minmass[1]]
|
|
||||||
|
|
||||||
self._data = data
|
self._data = data
|
||||||
|
if bounds is not None:
|
||||||
|
self.apply_bounds(bounds)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nsnap(self):
|
def nsnap(self):
|
||||||
|
@ -626,3 +672,35 @@ class QuijoteHaloCatalogue(BaseCatalogue):
|
||||||
box : instance of :py:class:`csiborgtools.units.BaseBox`
|
box : instance of :py:class:`csiborgtools.units.BaseBox`
|
||||||
"""
|
"""
|
||||||
return QuijoteBox(self.nsnap)
|
return QuijoteBox(self.nsnap)
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Utility functions for halo catalogues #
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def fiducial_observers(boxwidth, radius):
|
||||||
|
"""
|
||||||
|
Positions of fiducial observers in a box, such that that the box is
|
||||||
|
subdivided among them into spherical regions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
boxwidth : float
|
||||||
|
Box width.
|
||||||
|
radius : float
|
||||||
|
Radius of the spherical regions.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
origins : list of len-3 lists
|
||||||
|
Positions of the observers.
|
||||||
|
"""
|
||||||
|
nobs = floor(boxwidth / (2 * radius)) # Number of observers per dimension
|
||||||
|
|
||||||
|
origins = list(product([1, 3, 5], repeat=nobs))
|
||||||
|
for i in range(len(origins)):
|
||||||
|
origins[i] = list(origins[i])
|
||||||
|
for j in range(nobs):
|
||||||
|
origins[i][j] *= radius
|
||||||
|
return origins
|
||||||
|
|
|
@ -246,7 +246,8 @@ class PairOverlap:
|
||||||
prob_nomatch : 1-dimensional array of shape `(nhalos, )`
|
prob_nomatch : 1-dimensional array of shape `(nhalos, )`
|
||||||
"""
|
"""
|
||||||
overlap = self.overlap(from_smoothed)
|
overlap = self.overlap(from_smoothed)
|
||||||
return numpy.array([numpy.product(1 - overlap) for overlap in overlap])
|
return numpy.array([numpy.product(numpy.subtract(1, cross))
|
||||||
|
for cross in overlap])
|
||||||
|
|
||||||
def dist(self, in_initial, norm_kind=None):
|
def dist(self, in_initial, norm_kind=None):
|
||||||
"""
|
"""
|
||||||
|
@ -612,6 +613,31 @@ class NPairsOverlap:
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def get_cross_sims(nsim0, paths, smoothed):
|
||||||
|
"""
|
||||||
|
Get the list of cross simulations for a given reference simulation for
|
||||||
|
which the overlap has been calculated.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
nsim0 : int
|
||||||
|
Reference simulation number.
|
||||||
|
paths : :py:class:`csiborgtools.paths.Paths`
|
||||||
|
Paths object.
|
||||||
|
smoothed : bool
|
||||||
|
Whether to use the smoothed overlap or not.
|
||||||
|
"""
|
||||||
|
nsimxs = []
|
||||||
|
for nsimx in paths.get_ics():
|
||||||
|
if nsimx == nsim0:
|
||||||
|
continue
|
||||||
|
f1 = paths.overlap_path(nsim0, nsimx, smoothed)
|
||||||
|
f2 = paths.overlap_path(nsimx, nsim0, smoothed)
|
||||||
|
if isfile(f1) or isfile(f2):
|
||||||
|
nsimxs.append(nsimx)
|
||||||
|
return nsimxs
|
||||||
|
|
||||||
|
|
||||||
def binned_resample_mean(x, y, prob, bins, nresample=50, seed=42):
|
def binned_resample_mean(x, y, prob, bins, nresample=50, seed=42):
|
||||||
"""
|
"""
|
||||||
Calculate binned average of `y` by MC resampling. Each point is kept with
|
Calculate binned average of `y` by MC resampling. Each point is kept with
|
||||||
|
|
|
@ -88,6 +88,13 @@ class Paths:
|
||||||
self._check_directory(path)
|
self._check_directory(path)
|
||||||
self._quijote_dir = path
|
self._quijote_dir = path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_quijote_ics():
|
||||||
|
"""
|
||||||
|
Quijote IC realisation IDs.
|
||||||
|
"""
|
||||||
|
return numpy.arange(100, dtype=int)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def postdir(self):
|
def postdir(self):
|
||||||
"""
|
"""
|
||||||
|
@ -376,40 +383,52 @@ class Paths:
|
||||||
fname = f"{kind}_{MAS}_{str(nsim).zfill(5)}_grid{grid}.npy"
|
fname = f"{kind}_{MAS}_{str(nsim).zfill(5)}_grid{grid}.npy"
|
||||||
return join(fdir, fname)
|
return join(fdir, fname)
|
||||||
|
|
||||||
def knnauto_path(self, run, nsim=None):
|
def knnauto_path(self, simname, run, nsim=None, nobs=None):
|
||||||
"""
|
"""
|
||||||
Path to the `knn` auto-correlation files. If `nsim` is not specified
|
Path to the `knn` auto-correlation files. If `nsim` is not specified
|
||||||
returns a list of files for this run for all available simulations.
|
returns a list of files for this run for all available simulations.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
simname : str
|
||||||
|
Simulation name. Must be either `csiborg` or `quijote`.
|
||||||
run : str
|
run : str
|
||||||
Type of run.
|
Type of run.
|
||||||
nsim : int, optional
|
nsim : int, optional
|
||||||
IC realisation index.
|
IC realisation index.
|
||||||
|
nobs : int, optional
|
||||||
|
Fiducial observer index in Quijote simulations.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
path : str
|
path : str
|
||||||
"""
|
"""
|
||||||
|
assert simname in ["csiborg", "quijote"]
|
||||||
fdir = join(self.postdir, "knn", "auto")
|
fdir = join(self.postdir, "knn", "auto")
|
||||||
if not isdir(fdir):
|
if not isdir(fdir):
|
||||||
makedirs(fdir)
|
makedirs(fdir)
|
||||||
warn(f"Created directory `{fdir}`.", UserWarning, stacklevel=1)
|
warn(f"Created directory `{fdir}`.", UserWarning, stacklevel=1)
|
||||||
if nsim is not None:
|
if nsim is not None:
|
||||||
return join(fdir, f"knncdf_{str(nsim).zfill(5)}_{run}.p")
|
if simname == "csiborg":
|
||||||
|
nsim = str(nsim).zfill(5)
|
||||||
|
else:
|
||||||
|
assert nobs is not None
|
||||||
|
nsim = f"{str(nobs).zfill(2)}{str(nsim).zfill(3)}"
|
||||||
|
return join(fdir, f"{simname}_knncdf_{nsim}_{run}.p")
|
||||||
|
|
||||||
files = glob(join(fdir, "knncdf*"))
|
files = glob(join(fdir, f"{simname}_knncdf*"))
|
||||||
run = "__" + run
|
run = "__" + run
|
||||||
return [f for f in files if run in f]
|
return [f for f in files if run in f]
|
||||||
|
|
||||||
def knncross_path(self, run, nsims=None):
|
def knncross_path(self, simname, run, nsims=None):
|
||||||
"""
|
"""
|
||||||
Path to the `knn` cross-correlation files. If `nsims` is not specified
|
Path to the `knn` cross-correlation files. If `nsims` is not specified
|
||||||
returns a list of files for this run for all available simulations.
|
returns a list of files for this run for all available simulations.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
simname : str
|
||||||
|
Simulation name. Must be either `csiborg` or `quijote`.
|
||||||
run : str
|
run : str
|
||||||
Type of run.
|
Type of run.
|
||||||
nsims : len-2 tuple of int, optional
|
nsims : len-2 tuple of int, optional
|
||||||
|
@ -427,19 +446,21 @@ class Paths:
|
||||||
assert isinstance(nsims, (list, tuple)) and len(nsims) == 2
|
assert isinstance(nsims, (list, tuple)) and len(nsims) == 2
|
||||||
nsim0 = str(nsims[0]).zfill(5)
|
nsim0 = str(nsims[0]).zfill(5)
|
||||||
nsimx = str(nsims[1]).zfill(5)
|
nsimx = str(nsims[1]).zfill(5)
|
||||||
return join(fdir, f"knncdf_{nsim0}_{nsimx}__{run}.p")
|
return join(fdir, f"{simname}_knncdf_{nsim0}_{nsimx}__{run}.p")
|
||||||
|
|
||||||
files = glob(join(fdir, "knncdf*"))
|
files = glob(join(fdir, f"{simname}_knncdf*"))
|
||||||
run = "__" + run
|
run = "__" + run
|
||||||
return [f for f in files if run in f]
|
return [f for f in files if run in f]
|
||||||
|
|
||||||
def tpcfauto_path(self, run, nsim=None):
|
def tpcfauto_path(self, simname, run, nsim=None):
|
||||||
"""
|
"""
|
||||||
Path to the `tpcf` auto-correlation files. If `nsim` is not specified
|
Path to the `tpcf` auto-correlation files. If `nsim` is not specified
|
||||||
returns a list of files for this run for all available simulations.
|
returns a list of files for this run for all available simulations.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
|
simname : str
|
||||||
|
Simulation name. Must be either `csiborg` or `quijote`.
|
||||||
run : str
|
run : str
|
||||||
Type of run.
|
Type of run.
|
||||||
nsim : int, optional
|
nsim : int, optional
|
||||||
|
@ -454,8 +475,8 @@ class Paths:
|
||||||
makedirs(fdir)
|
makedirs(fdir)
|
||||||
warn(f"Created directory `{fdir}`.", UserWarning, stacklevel=1)
|
warn(f"Created directory `{fdir}`.", UserWarning, stacklevel=1)
|
||||||
if nsim is not None:
|
if nsim is not None:
|
||||||
return join(fdir, f"tpcf{str(nsim).zfill(5)}_{run}.p")
|
return join(fdir, f"{simname}_tpcf{str(nsim).zfill(5)}_{run}.p")
|
||||||
|
|
||||||
files = glob(join(fdir, "tpcf*"))
|
files = glob(join(fdir, f"{simname}_tpcf*"))
|
||||||
run = "__" + run
|
run = "__" + run
|
||||||
return [f for f in files if run in f]
|
return [f for f in files if run in f]
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -43,58 +43,76 @@ nproc = comm.Get_size()
|
||||||
|
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("--runs", type=str, nargs="+")
|
parser.add_argument("--runs", type=str, nargs="+")
|
||||||
|
parser.add_argument("--ics", type=int, nargs="+", default=None,
|
||||||
|
help="IC realisations. If `-1` processes all simulations.")
|
||||||
|
parser.add_argument("--simname", type=str, choices=["csiborg", "quijote"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
with open("../scripts/knn_auto.yml", "r") as file:
|
with open("../scripts/cluster_knn_auto.yml", "r") as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
Rmax = 155 / 0.705 # Mpc (h = 0.705) high resolution region radius
|
Rmax = 155 / 0.705 # Mpc (h = 0.705) high resolution region radius
|
||||||
totvol = 4 * numpy.pi * Rmax**3 / 3
|
totvol = 4 * numpy.pi * Rmax**3 / 3
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
ics = paths.get_ics()
|
knncdf = csiborgtools.clustering.kNN_1DCDF()
|
||||||
knncdf = csiborgtools.clustering.kNN_CDF()
|
|
||||||
|
if args.ics is None or args.ics[0] == -1:
|
||||||
|
if args.simname == "csiborg":
|
||||||
|
ics = paths.get_ics()
|
||||||
|
else:
|
||||||
|
ics = paths.get_quijote_ics()
|
||||||
|
else:
|
||||||
|
ics = args.ics
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Analysis #
|
# Analysis #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
def read_single(selection, cat):
|
def read_single(nsim, selection, nobs=None):
|
||||||
"""Positions for single catalogue auto-correlation."""
|
# We first read the full catalogue without applying any bounds.
|
||||||
mmask = numpy.ones(len(cat), dtype=bool)
|
if args.simname == "csiborg":
|
||||||
pos = cat.positions(False)
|
cat = csiborgtools.read.HaloCatalogue(nsim, paths)
|
||||||
# Primary selection
|
else:
|
||||||
psel = selection["primary"]
|
cat = csiborgtools.read.QuijoteHaloCatalogue(nsim, paths, nsnap=4,
|
||||||
pmin, pmax = psel.get("min", None), psel.get("max", None)
|
origin=nobs)
|
||||||
if pmin is not None:
|
|
||||||
mmask &= cat[psel["name"]] >= pmin
|
|
||||||
if pmax is not None:
|
|
||||||
mmask &= cat[psel["name"]] < pmax
|
|
||||||
pos = pos[mmask, ...]
|
|
||||||
|
|
||||||
# Secondary selection
|
cat.apply_bounds({"dist": (0, Rmax)})
|
||||||
if "secondary" not in selection:
|
# We then first read off the primary selection bounds.
|
||||||
return pos
|
sel = selection["primary"]
|
||||||
smask = numpy.ones(pos.shape[0], dtype=bool)
|
pname = None
|
||||||
ssel = selection["secondary"]
|
xs = sel["names"] if isinstance(sel["names"], list) else [sel["names"]]
|
||||||
smin, smax = ssel.get("min", None), ssel.get("max", None)
|
for _name in xs:
|
||||||
prop = cat[ssel["name"]][mmask]
|
if _name in cat.keys:
|
||||||
if ssel.get("toperm", False):
|
pname = _name
|
||||||
prop = numpy.random.permutation(prop)
|
if pname is None:
|
||||||
if ssel.get("marked", True):
|
raise KeyError(f"Invalid names `{sel['name']}`.")
|
||||||
x = cat[psel["name"]][mmask]
|
|
||||||
prop = csiborgtools.clustering.normalised_marks(
|
|
||||||
x, prop, nbins=config["nbins_marks"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if smin is not None:
|
cat.apply_bounds({pname: (sel.get("min", None), sel.get("max", None))})
|
||||||
smask &= prop >= smin
|
|
||||||
if smax is not None:
|
|
||||||
smask &= prop < smax
|
|
||||||
|
|
||||||
return pos[smask, ...]
|
# Now the secondary selection bounds. If needed transfrom the secondary
|
||||||
|
# property before applying the bounds.
|
||||||
|
if "secondary" in selection:
|
||||||
|
sel = selection["secondary"]
|
||||||
|
sname = None
|
||||||
|
xs = sel["names"] if isinstance(sel["names"], list) else [sel["names"]]
|
||||||
|
for _name in xs:
|
||||||
|
if _name in cat.keys:
|
||||||
|
sname = _name
|
||||||
|
if sname is None:
|
||||||
|
raise KeyError(f"Invalid names `{sel['name']}`.")
|
||||||
|
|
||||||
|
if sel.get("toperm", False):
|
||||||
|
cat[sname] = numpy.random.permutation(cat[sname])
|
||||||
|
|
||||||
|
if sel.get("marked", False):
|
||||||
|
cat[sname] = csiborgtools.clustering.normalised_marks(
|
||||||
|
cat[pname], cat[sname], nbins=config["nbins_marks"])
|
||||||
|
cat.apply_bounds({sname: (sel.get("min", None), sel.get("max", None))})
|
||||||
|
return cat
|
||||||
|
|
||||||
|
|
||||||
def do_auto(run, cat, ic):
|
def do_auto(run, nsim, nobs=None):
|
||||||
"""Calculate the kNN-CDF single catalgoue autocorrelation."""
|
"""Calculate the kNN-CDF single catalgoue autocorrelation."""
|
||||||
_config = config.get(run, None)
|
_config = config.get(run, None)
|
||||||
if _config is None:
|
if _config is None:
|
||||||
|
@ -102,22 +120,20 @@ def do_auto(run, cat, ic):
|
||||||
return
|
return
|
||||||
|
|
||||||
rvs_gen = csiborgtools.clustering.RVSinsphere(Rmax)
|
rvs_gen = csiborgtools.clustering.RVSinsphere(Rmax)
|
||||||
pos = read_single(_config, cat)
|
cat = read_single(nsim, _config, nobs=nobs)
|
||||||
knn = NearestNeighbors()
|
knn = cat.knn(in_initial=False)
|
||||||
knn.fit(pos)
|
|
||||||
rs, cdf = knncdf(
|
rs, cdf = knncdf(
|
||||||
knn, rvs_gen=rvs_gen, nneighbours=config["nneighbours"],
|
knn, rvs_gen=rvs_gen, nneighbours=config["nneighbours"],
|
||||||
rmin=config["rmin"], rmax=config["rmax"],
|
rmin=config["rmin"], rmax=config["rmax"],
|
||||||
nsamples=int(config["nsamples"]), neval=int(config["neval"]),
|
nsamples=int(config["nsamples"]), neval=int(config["neval"]),
|
||||||
batch_size=int(config["batch_size"]), random_state=config["seed"])
|
batch_size=int(config["batch_size"]), random_state=config["seed"])
|
||||||
|
|
||||||
joblib.dump(
|
fout = paths.knnauto_path(args.simname, run, nsim, nobs)
|
||||||
{"rs": rs, "cdf": cdf, "ndensity": pos.shape[0] / totvol},
|
print(f"Saving output to `{fout}`.")
|
||||||
paths.knnauto_path(run, ic),
|
joblib.dump({"rs": rs, "cdf": cdf, "ndensity": len(cat) / totvol}, fout)
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def do_cross_rand(run, cat, ic):
|
def do_cross_rand(run, nsim, nobs=None):
|
||||||
"""Calculate the kNN-CDF cross catalogue random correlation."""
|
"""Calculate the kNN-CDF cross catalogue random correlation."""
|
||||||
_config = config.get(run, None)
|
_config = config.get(run, None)
|
||||||
if _config is None:
|
if _config is None:
|
||||||
|
@ -125,31 +141,32 @@ def do_cross_rand(run, cat, ic):
|
||||||
return
|
return
|
||||||
|
|
||||||
rvs_gen = csiborgtools.clustering.RVSinsphere(Rmax)
|
rvs_gen = csiborgtools.clustering.RVSinsphere(Rmax)
|
||||||
knn1, knn2 = NearestNeighbors(), NearestNeighbors()
|
cat = read_single(nsim, _config)
|
||||||
|
knn1 = cat.knn(in_initial=False)
|
||||||
|
|
||||||
pos1 = read_single(_config, cat)
|
knn2 = NearestNeighbors()
|
||||||
knn1.fit(pos1)
|
pos2 = rvs_gen(len(cat).shape[0])
|
||||||
|
|
||||||
pos2 = rvs_gen(pos1.shape[0])
|
|
||||||
knn2.fit(pos2)
|
knn2.fit(pos2)
|
||||||
|
|
||||||
rs, cdf0, cdf1, joint_cdf = knncdf.joint(
|
rs, cdf0, cdf1, joint_cdf = knncdf.joint(
|
||||||
knn1, knn2, rvs_gen=rvs_gen, nneighbours=int(config["nneighbours"]),
|
knn1, knn2, rvs_gen=rvs_gen, nneighbours=int(config["nneighbours"]),
|
||||||
rmin=config["rmin"], rmax=config["rmax"],
|
rmin=config["rmin"], rmax=config["rmax"],
|
||||||
nsamples=int(config["nsamples"]), neval=int(config["neval"]),
|
nsamples=int(config["nsamples"]), neval=int(config["neval"]),
|
||||||
batch_size=int(config["batch_size"]), random_state=config["seed"],
|
batch_size=int(config["batch_size"]), random_state=config["seed"])
|
||||||
)
|
|
||||||
corr = knncdf.joint_to_corr(cdf0, cdf1, joint_cdf)
|
corr = knncdf.joint_to_corr(cdf0, cdf1, joint_cdf)
|
||||||
joblib.dump({"rs": rs, "corr": corr}, paths.knnauto_path(run, ic))
|
fout = paths.knnauto_path(args.simname, run, nsim, nobs)
|
||||||
|
print(f"Saving output to `{fout}`.")
|
||||||
|
joblib.dump({"rs": rs, "corr": corr}, fout)
|
||||||
|
|
||||||
|
|
||||||
def do_runs(ic):
|
def do_runs(nsim):
|
||||||
cat = csiborgtools.read.ClumpsCatalogue(ic, paths, maxdist=Rmax)
|
|
||||||
for run in args.runs:
|
for run in args.runs:
|
||||||
if "random" in run:
|
iters = range(27) if args.simname == "quijote" else [None]
|
||||||
do_cross_rand(run, cat, ic)
|
for nobs in iters:
|
||||||
else:
|
if "random" in run:
|
||||||
do_auto(run, cat, ic)
|
do_cross_rand(run, nsim, nobs)
|
||||||
|
else:
|
||||||
|
do_auto(run, nsim, nobs)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
rmin: 0.1
|
rmin: 0.1
|
||||||
rmax: 100
|
rmax: 100
|
||||||
nneighbours: 64
|
nneighbours: 8
|
||||||
nsamples: 1.e+7
|
nsamples: 1.e+5
|
||||||
batch_size: 1.e+6
|
batch_size: 5.e+4
|
||||||
neval: 10000
|
neval: 10000
|
||||||
seed: 42
|
seed: 42
|
||||||
nbins_marks: 10
|
nbins_marks: 10
|
||||||
|
@ -15,19 +15,25 @@ nbins_marks: 10
|
||||||
|
|
||||||
"mass001":
|
"mass001":
|
||||||
primary:
|
primary:
|
||||||
name: totpartmass
|
name:
|
||||||
|
- totpartmass,
|
||||||
|
- group_mass
|
||||||
min: 1.e+12
|
min: 1.e+12
|
||||||
max: 1.e+13
|
max: 1.e+13
|
||||||
|
|
||||||
"mass002":
|
"mass002":
|
||||||
primary:
|
primary:
|
||||||
name: totpartmass
|
name:
|
||||||
|
- totpartmass,
|
||||||
|
- group_mass
|
||||||
min: 1.e+13
|
min: 1.e+13
|
||||||
max: 1.e+14
|
max: 1.e+14
|
||||||
|
|
||||||
"mass003":
|
"mass003":
|
||||||
primary:
|
primary:
|
||||||
name: totpartmass
|
name:
|
||||||
|
- totpartmass,
|
||||||
|
- group_mass
|
||||||
min: 1.e+14
|
min: 1.e+14
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,15 @@
|
||||||
# You should have received a copy of the GNU General Public License along
|
# You should have received a copy of the GNU General Public License along
|
||||||
# with this program; if not, write to the Free Software Foundation, Inc.,
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
"""A script to calculate the KNN-CDF for a set of CSiBORG halo catalogues."""
|
"""
|
||||||
|
A script to calculate the KNN-CDF for a set of CSiBORG halo catalogues.
|
||||||
|
|
||||||
|
TODO:
|
||||||
|
- [ ] Update catalogue readers.
|
||||||
|
- [ ] Update paths.
|
||||||
|
- [ ] Update to cross-correlate different mass populations from different
|
||||||
|
simulations.
|
||||||
|
"""
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
|
@ -43,6 +51,7 @@ nproc = comm.Get_size()
|
||||||
|
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("--runs", type=str, nargs="+")
|
parser.add_argument("--runs", type=str, nargs="+")
|
||||||
|
parser.add_argument("--simname", type=str, choices=["csiborg", "quijote"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
with open("../scripts/knn_cross.yml", "r") as file:
|
with open("../scripts/knn_cross.yml", "r") as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
|
@ -50,7 +59,7 @@ with open("../scripts/knn_cross.yml", "r") as file:
|
||||||
Rmax = 155 / 0.705 # Mpc (h = 0.705) high resolution region radius
|
Rmax = 155 / 0.705 # Mpc (h = 0.705) high resolution region radius
|
||||||
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
ics = paths.get_ics()
|
ics = paths.get_ics()
|
||||||
knncdf = csiborgtools.clustering.kNN_CDF()
|
knncdf = csiborgtools.clustering.kNN_1DCDF()
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Analysis #
|
# Analysis #
|
||||||
|
@ -100,13 +109,13 @@ def do_cross(run, ics):
|
||||||
)
|
)
|
||||||
|
|
||||||
corr = knncdf.joint_to_corr(cdf0, cdf1, joint_cdf)
|
corr = knncdf.joint_to_corr(cdf0, cdf1, joint_cdf)
|
||||||
joblib.dump({"rs": rs, "corr": corr}, paths.knncross_path(run, ics))
|
fout = paths.knncross_path(args.simname, run, ics)
|
||||||
|
joblib.dump({"rs": rs, "corr": corr}, fout)
|
||||||
|
|
||||||
|
|
||||||
def do_runs(ics):
|
def do_runs(nsims):
|
||||||
print(ics)
|
|
||||||
for run in args.runs:
|
for run in args.runs:
|
||||||
do_cross(run, ics)
|
do_cross(run, nsims)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
|
@ -12,7 +12,9 @@
|
||||||
# You should have received a copy of the GNU General Public License along
|
# You should have received a copy of the GNU General Public License along
|
||||||
# with this program; if not, write to the Free Software Foundation, Inc.,
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
"""A script to calculate the auto-2PCF of CSiBORG catalogues."""
|
"""
|
||||||
|
A script to calculate the auto-2PCF of CSiBORG catalogues.
|
||||||
|
"""
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -22,8 +24,11 @@ import joblib
|
||||||
import numpy
|
import numpy
|
||||||
import yaml
|
import yaml
|
||||||
from mpi4py import MPI
|
from mpi4py import MPI
|
||||||
|
|
||||||
from taskmaster import master_process, worker_process
|
from taskmaster import master_process, worker_process
|
||||||
|
|
||||||
|
from .cluster_knn_auto import read_single
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import csiborgtools
|
import csiborgtools
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
|
@ -42,57 +47,31 @@ nproc = comm.Get_size()
|
||||||
|
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("--runs", type=str, nargs="+")
|
parser.add_argument("--runs", type=str, nargs="+")
|
||||||
|
parser.add_argument("--ics", type=int, nargs="+", default=None,
|
||||||
|
help="IC realisations. If `-1` processes all simulations.")
|
||||||
|
parser.add_argument("--simname", type=str, choices=["csiborg", "quijote"])
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
with open("../scripts/tpcf_auto.yml", "r") as file:
|
with open("../scripts/tpcf_auto.yml", "r") as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
Rmax = 155 / 0.705 # Mpc (h = 0.705) high resolution region radius
|
Rmax = 155 / 0.705 # Mpc (h = 0.705) high resolution region radius
|
||||||
paths = csiborgtools.read.Paths()
|
paths = csiborgtools.read.Paths()
|
||||||
ics = paths.get_ics()
|
|
||||||
tpcf = csiborgtools.clustering.Mock2PCF()
|
tpcf = csiborgtools.clustering.Mock2PCF()
|
||||||
|
|
||||||
|
if args.ics is None or args.ics[0] == -1:
|
||||||
|
if args.simname == "csiborg":
|
||||||
|
ics = paths.get_ics()
|
||||||
|
else:
|
||||||
|
ics = paths.get_quijote_ics()
|
||||||
|
else:
|
||||||
|
ics = args.ics
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Analysis #
|
# Analysis #
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
def read_single(selection, cat):
|
def do_auto(run, nsim):
|
||||||
"""Positions for single catalogue auto-correlation."""
|
|
||||||
mmask = numpy.ones(len(cat), dtype=bool)
|
|
||||||
pos = cat.positions(False)
|
|
||||||
# Primary selection
|
|
||||||
psel = selection["primary"]
|
|
||||||
pmin, pmax = psel.get("min", None), psel.get("max", None)
|
|
||||||
if pmin is not None:
|
|
||||||
mmask &= cat[psel["name"]] >= pmin
|
|
||||||
if pmax is not None:
|
|
||||||
mmask &= cat[psel["name"]] < pmax
|
|
||||||
pos = pos[mmask, ...]
|
|
||||||
|
|
||||||
# Secondary selection
|
|
||||||
if "secondary" not in selection:
|
|
||||||
return pos
|
|
||||||
smask = numpy.ones(pos.shape[0], dtype=bool)
|
|
||||||
ssel = selection["secondary"]
|
|
||||||
smin, smax = ssel.get("min", None), ssel.get("max", None)
|
|
||||||
prop = cat[ssel["name"]][mmask]
|
|
||||||
if ssel.get("toperm", False):
|
|
||||||
prop = numpy.random.permutation(prop)
|
|
||||||
if ssel.get("marked", True):
|
|
||||||
x = cat[psel["name"]][mmask]
|
|
||||||
prop = csiborgtools.clustering.normalised_marks(
|
|
||||||
x, prop, nbins=config["nbins_marks"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if smin is not None:
|
|
||||||
smask &= prop >= smin
|
|
||||||
if smax is not None:
|
|
||||||
smask &= prop < smax
|
|
||||||
|
|
||||||
return pos[smask, ...]
|
|
||||||
|
|
||||||
|
|
||||||
def do_auto(run, cat, ic):
|
|
||||||
_config = config.get(run, None)
|
_config = config.get(run, None)
|
||||||
if _config is None:
|
if _config is None:
|
||||||
warn("No configuration for run {}.".format(run), stacklevel=1)
|
warn("No configuration for run {}.".format(run), stacklevel=1)
|
||||||
|
@ -104,17 +83,18 @@ def do_auto(run, cat, ic):
|
||||||
numpy.log10(config["rpmax"]),
|
numpy.log10(config["rpmax"]),
|
||||||
config["nrpbins"] + 1,
|
config["nrpbins"] + 1,
|
||||||
)
|
)
|
||||||
pos = read_single(_config, cat)
|
cat = read_single(nsim, _config)
|
||||||
|
pos = cat.position(in_initial=False, cartesian=True)
|
||||||
nrandom = int(config["randmult"] * pos.shape[0])
|
nrandom = int(config["randmult"] * pos.shape[0])
|
||||||
rp, wp = tpcf(pos, rvs_gen, nrandom, bins)
|
rp, wp = tpcf(pos, rvs_gen, nrandom, bins)
|
||||||
|
|
||||||
joblib.dump({"rp": rp, "wp": wp}, paths.tpcfauto_path(run, ic))
|
fout = paths.tpcfauto_path(args.simname, run, nsim)
|
||||||
|
joblib.dump({"rp": rp, "wp": wp}, fout)
|
||||||
|
|
||||||
|
|
||||||
def do_runs(ic):
|
def do_runs(nsim):
|
||||||
cat = csiborgtools.read.ClumpsCatalogue(ic, paths, maxdist=Rmax)
|
|
||||||
for run in args.runs:
|
for run in args.runs:
|
||||||
do_auto(run, cat, ic)
|
do_auto(run, nsim)
|
||||||
|
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
|
@ -65,7 +65,7 @@ for i, nsim in enumerate(nsims):
|
||||||
particles = f["particles"]
|
particles = f["particles"]
|
||||||
clump_map = f["clumpmap"]
|
clump_map = f["clumpmap"]
|
||||||
clid2map = {clid: i for i, clid in enumerate(clump_map[:, 0])}
|
clid2map = {clid: i for i, clid in enumerate(clump_map[:, 0])}
|
||||||
clumps_cat = csiborgtools.read.ClumpsCatalogue(nsim, paths, rawdata=True,
|
clumps_cat = csiborgtools.read.ClumpsCatalogue(nsim, paths, rawdata=True,
|
||||||
load_fitted=False)
|
load_fitted=False)
|
||||||
ismain = clumps_cat.ismain
|
ismain = clumps_cat.ismain
|
||||||
ntasks = len(clumps_cat)
|
ntasks = len(clumps_cat)
|
||||||
|
|
|
@ -39,12 +39,11 @@ def pair_match(nsim0, nsimx, sigma, smoothen, verbose):
|
||||||
|
|
||||||
# Load the raw catalogues (i.e. no selection) including the initial CM
|
# Load the raw catalogues (i.e. no selection) including the initial CM
|
||||||
# positions and the particle archives.
|
# positions and the particle archives.
|
||||||
cat0 = HaloCatalogue(nsim0, paths, load_initial=True,
|
bounds = {"totpartmass": (1e12, None)}
|
||||||
minmass=("totpartmass", 1e12), with_lagpatch=True,
|
cat0 = HaloCatalogue(nsim0, paths, load_initial=True, bounds=bounds,
|
||||||
load_clumps_cat=True)
|
with_lagpatch=True, load_clumps_cat=True)
|
||||||
catx = HaloCatalogue(nsimx, paths, load_initial=True,
|
catx = HaloCatalogue(nsimx, paths, load_initial=True, bounds=bounds,
|
||||||
minmass=("totpartmass", 1e12), with_lagpatch=True,
|
with_lagpatch=True, load_clumps_cat=True)
|
||||||
load_clumps_cat=True)
|
|
||||||
|
|
||||||
clumpmap0 = read_h5(paths.particles_path(nsim0))["clumpmap"]
|
clumpmap0 = read_h5(paths.particles_path(nsim0))["clumpmap"]
|
||||||
parts0 = read_h5(paths.initmatch_path(nsim0, "particles"))["particles"]
|
parts0 = read_h5(paths.initmatch_path(nsim0, "particles"))["particles"]
|
||||||
|
|
154
scripts_plots/overlap.py
Normal file
154
scripts_plots/overlap.py
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
# Copyright (C) 2023 Richard Stiskalek
|
||||||
|
# This program is free software; you can redistribute it and/or modify it
|
||||||
|
# under the terms of the GNU General Public License as published by the
|
||||||
|
# Free Software Foundation; either version 3 of the License, or (at your
|
||||||
|
# option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful, but
|
||||||
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
|
||||||
|
# Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License along
|
||||||
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
|
|
||||||
|
from os.path import join
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy
|
||||||
|
|
||||||
|
import scienceplots # noqa
|
||||||
|
import utils
|
||||||
|
from cache_to_disk import cache_to_disk, delete_disk_caches_for_function
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
import csiborgtools
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
import sys
|
||||||
|
sys.path.append("../")
|
||||||
|
import csiborgtools
|
||||||
|
|
||||||
|
|
||||||
|
###############################################################################
|
||||||
|
# Probability of matching a reference simulation halo #
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def open_cat(nsim):
|
||||||
|
"""
|
||||||
|
Open a CSiBORG halo catalogue.
|
||||||
|
"""
|
||||||
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
|
bounds = {"totpartmass": (1e12, None)}
|
||||||
|
return csiborgtools.read.HaloCatalogue(nsim, paths, bounds=bounds)
|
||||||
|
|
||||||
|
|
||||||
|
@cache_to_disk(7)
|
||||||
|
def get_overlap(nsim0):
|
||||||
|
"""
|
||||||
|
Calculate the summed overlap and probability of no match for a single
|
||||||
|
reference simulation.
|
||||||
|
"""
|
||||||
|
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)
|
||||||
|
nsimxs = csiborgtools.read.get_cross_sims(nsim0, paths, smoothed=True)
|
||||||
|
cat0 = open_cat(nsim0)
|
||||||
|
|
||||||
|
catxs = []
|
||||||
|
for nsimx in tqdm(nsimxs):
|
||||||
|
catxs.append(open_cat(nsimx))
|
||||||
|
|
||||||
|
reader = csiborgtools.read.NPairsOverlap(cat0, catxs, paths)
|
||||||
|
x = reader.cat0("totpartmass")
|
||||||
|
summed_overlap = reader.summed_overlap(True)
|
||||||
|
prob_nomatch = reader.prob_nomatch(True)
|
||||||
|
return x, summed_overlap, prob_nomatch
|
||||||
|
|
||||||
|
|
||||||
|
def plot_summed_overlap(nsim0):
|
||||||
|
"""
|
||||||
|
Plot the summed overlap and probability of no matching for a single
|
||||||
|
reference simulation as a function of the reference halo mass.
|
||||||
|
"""
|
||||||
|
x, summed_overlap, prob_nomatch = get_overlap(nsim0)
|
||||||
|
|
||||||
|
mean_overlap = numpy.mean(summed_overlap, axis=1)
|
||||||
|
std_overlap = numpy.std(summed_overlap, axis=1)
|
||||||
|
|
||||||
|
mean_prob_nomatch = numpy.mean(prob_nomatch, axis=1)
|
||||||
|
# std_prob_nomatch = numpy.std(prob_nomatch, axis=1)
|
||||||
|
|
||||||
|
mask = mean_overlap > 0
|
||||||
|
x = x[mask]
|
||||||
|
mean_overlap = mean_overlap[mask]
|
||||||
|
std_overlap = std_overlap[mask]
|
||||||
|
mean_prob_nomatch = mean_prob_nomatch[mask]
|
||||||
|
|
||||||
|
# Mean summed overlap
|
||||||
|
with plt.style.context(utils.mplstyle):
|
||||||
|
plt.figure()
|
||||||
|
plt.hexbin(x, mean_overlap, mincnt=1, xscale="log", bins="log",
|
||||||
|
gridsize=50)
|
||||||
|
plt.colorbar(label="Counts in bins")
|
||||||
|
plt.xlabel(r"$M_{\rm tot} / M_\odot$")
|
||||||
|
plt.ylabel(r"$\langle \mathcal{O}_{a}^{\mathcal{A} \mathcal{B}} \rangle_{\mathcal{B}}$") # noqa
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
for ext in ["png", "pdf"]:
|
||||||
|
fout = join(utils.fout, f"overlap_mean_{nsim0}.{ext}")
|
||||||
|
print(f"Saving to `{fout}`.")
|
||||||
|
plt.savefig(fout, dpi=utils.dpi, bbox_inches="tight")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
# Std summed overlap
|
||||||
|
with plt.style.context(utils.mplstyle):
|
||||||
|
plt.figure()
|
||||||
|
plt.hexbin(x, std_overlap, mincnt=1, xscale="log", bins="log",
|
||||||
|
gridsize=50)
|
||||||
|
plt.colorbar(label="Counts in bins")
|
||||||
|
plt.xlabel(r"$M_{\rm tot} / M_\odot$")
|
||||||
|
plt.ylabel(r"$\delta \left( \mathcal{O}_{a}^{\mathcal{A} \mathcal{B}} \right)_{\mathcal{B}}$") # noqa
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
for ext in ["png", "pdf"]:
|
||||||
|
fout = join(utils.fout, f"overlap_std_{nsim0}.{ext}")
|
||||||
|
print(f"Saving to `{fout}`.")
|
||||||
|
plt.savefig(fout, dpi=utils.dpi, bbox_inches="tight")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
# 1 - mean summed overlap vs mean prob nomatch
|
||||||
|
with plt.style.context(utils.mplstyle):
|
||||||
|
plt.figure()
|
||||||
|
plt.scatter(1 - mean_overlap, mean_prob_nomatch, c=numpy.log10(x), s=2,
|
||||||
|
rasterized=True)
|
||||||
|
plt.colorbar(label=r"$\log_{10} M_{\rm halo} / M_\odot$")
|
||||||
|
|
||||||
|
t = numpy.linspace(0.3, 1, 100)
|
||||||
|
plt.plot(t, t, color="red", linestyle="--")
|
||||||
|
|
||||||
|
plt.xlabel(r"$1 - \langle \mathcal{O}_a^{\mathcal{A} \mathcal{B}} \rangle_{\mathcal{B}}$") # noqa
|
||||||
|
plt.ylabel(r"$\langle \eta_a^{\mathcal{A} \mathcal{B}} \rangle_{\mathcal{B}}$") # noqa
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
for ext in ["png", "pdf"]:
|
||||||
|
fout = join(utils.fout, f"overlap_vs_prob_nomatch_{nsim0}.{ext}")
|
||||||
|
print(f"Saving to `{fout}`.")
|
||||||
|
plt.savefig(fout, dpi=utils.dpi, bbox_inches="tight")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('-c', '--clean', action='store_true')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
cached_funcs = ["get_overlap"]
|
||||||
|
if args.clean:
|
||||||
|
for func in cached_funcs:
|
||||||
|
print(f"Cleaning cache for function {func}.")
|
||||||
|
delete_disk_caches_for_function(func)
|
||||||
|
|
||||||
|
for ic in [7444, 8812, 9700]:
|
||||||
|
plot_summed_overlap(ic)
|
18
scripts_plots/utils.py
Normal file
18
scripts_plots/utils.py
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
# Copyright (C) 2023 Richard Stiskalek
|
||||||
|
# This program is free software; you can redistribute it and/or modify it
|
||||||
|
# under the terms of the GNU General Public License as published by the
|
||||||
|
# Free Software Foundation; either version 3 of the License, or (at your
|
||||||
|
# option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful, but
|
||||||
|
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
|
||||||
|
# Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License along
|
||||||
|
# with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
|
||||||
|
|
||||||
|
dpi = 450
|
||||||
|
fout = "../plots/"
|
||||||
|
mplstyle = ["notebook"]
|
Loading…
Reference in a new issue