mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 22:28:03 +00:00
Add brute KNN
This commit is contained in:
parent
097b498da6
commit
7610def8a0
1 changed files with 49 additions and 1 deletions
|
@ -15,7 +15,6 @@
|
||||||
"""
|
"""
|
||||||
kNN-CDF calculation
|
kNN-CDF calculation
|
||||||
"""
|
"""
|
||||||
from gc import collect
|
|
||||||
import numpy
|
import numpy
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
from scipy.stats import binned_statistic
|
from scipy.stats import binned_statistic
|
||||||
|
@ -125,6 +124,55 @@ class kNN_CDF:
|
||||||
cdf[cdf > 0.5] = 1 - cdf[cdf > 0.5]
|
cdf[cdf > 0.5] = 1 - cdf[cdf > 0.5]
|
||||||
return cdf
|
return cdf
|
||||||
|
|
||||||
|
def brute_cdf(self, knn, nneighbours, Rmax, nsamples, rmin, rmax, neval,
|
||||||
|
random_state=42, dtype=numpy.float32):
|
||||||
|
"""
|
||||||
|
Calculate the CDF for a kNN of CSiBORG halo catalogues without batch
|
||||||
|
sizing. This can become memory intense for large numbers of randoms
|
||||||
|
and, therefore, is only for testing purposes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
knns : `sklearn.neighbors.NearestNeighbors`
|
||||||
|
kNN of CSiBORG halo catalogues.
|
||||||
|
neighbours : int
|
||||||
|
Maximum number of neighbours to use for the kNN-CDF calculation.
|
||||||
|
Rmax : float
|
||||||
|
Maximum radius of the sphere in which to sample random points for
|
||||||
|
the knn-CDF calculation. This should match the CSiBORG catalogues.
|
||||||
|
nsamples : int
|
||||||
|
Number of random points to sample for the knn-CDF calculation.
|
||||||
|
rmin : float
|
||||||
|
Minimum distance to evaluate the CDF.
|
||||||
|
rmax : float
|
||||||
|
Maximum distance to evaluate the CDF.
|
||||||
|
neval : int
|
||||||
|
Number of points to evaluate the CDF.
|
||||||
|
random_state : int, optional
|
||||||
|
Random state for the random number generator.
|
||||||
|
dtype : numpy dtype, optional
|
||||||
|
Calculation data type. By default `numpy.float32`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
rs : 1-dimensional array
|
||||||
|
Distances at which the CDF is evaluated.
|
||||||
|
cdfs : 2-dimensional array
|
||||||
|
CDFs evaluated at `rs`.
|
||||||
|
"""
|
||||||
|
rand = self.rvs_in_sphere(nsamples, Rmax, random_state=random_state)
|
||||||
|
|
||||||
|
dist, __ = knn.kneighbors(rand, nneighbours)
|
||||||
|
dist = dist.astype(dtype)
|
||||||
|
|
||||||
|
cdf = [None] * nneighbours
|
||||||
|
for j in range(nneighbours):
|
||||||
|
rs, cdf[j] = self.cdf_from_samples(dist[:, j], rmin=rmin,
|
||||||
|
rmax=rmax, neval=neval)
|
||||||
|
|
||||||
|
cdf = numpy.asanyarray(cdf)
|
||||||
|
return rs, cdf
|
||||||
|
|
||||||
def __call__(self, *knns, nneighbours, Rmax, nsamples, rmin, rmax, neval,
|
def __call__(self, *knns, nneighbours, Rmax, nsamples, rmin, rmax, neval,
|
||||||
batch_size=None, verbose=True, random_state=42,
|
batch_size=None, verbose=True, random_state=42,
|
||||||
left_nan=True, right_nan=True, dtype=numpy.float32):
|
left_nan=True, right_nan=True, dtype=numpy.float32):
|
||||||
|
|
Loading…
Reference in a new issue