diff --git a/csiborgtools/match/knn.py b/csiborgtools/match/knn.py index 4a1550e..f6211d5 100644 --- a/csiborgtools/match/knn.py +++ b/csiborgtools/match/knn.py @@ -15,7 +15,6 @@ """ kNN-CDF calculation """ -from gc import collect import numpy from scipy.interpolate import interp1d from scipy.stats import binned_statistic @@ -125,6 +124,55 @@ class kNN_CDF: cdf[cdf > 0.5] = 1 - cdf[cdf > 0.5] return cdf + def brute_cdf(self, knn, nneighbours, Rmax, nsamples, rmin, rmax, neval, + random_state=42, dtype=numpy.float32): + """ + Calculate the CDF for a kNN of CSiBORG halo catalogues without batch + sizing. This can become memory intensive for large numbers of randoms + and, therefore, is only for testing purposes. + + Parameters + ---------- + knn : `sklearn.neighbors.NearestNeighbors` + kNN of CSiBORG halo catalogues. + nneighbours : int + Maximum number of neighbours to use for the kNN-CDF calculation. + Rmax : float + Maximum radius of the sphere in which to sample random points for + the kNN-CDF calculation. This should match the CSiBORG catalogues. + nsamples : int + Number of random points to sample for the kNN-CDF calculation. + rmin : float + Minimum distance to evaluate the CDF. + rmax : float + Maximum distance to evaluate the CDF. + neval : int + Number of points to evaluate the CDF. + random_state : int, optional + Random state for the random number generator. + dtype : numpy dtype, optional + Calculation data type. By default `numpy.float32`. + + Returns + ------- + rs : 1-dimensional array + Distances at which the CDF is evaluated. + cdfs : 2-dimensional array + CDFs evaluated at `rs`, one per neighbour along the first axis. 
+ """ + rand = self.rvs_in_sphere(nsamples, Rmax, random_state=random_state) + + dist, __ = knn.kneighbors(rand, nneighbours) + dist = dist.astype(dtype) + + cdf = [None] * nneighbours + for j in range(nneighbours): + rs, cdf[j] = self.cdf_from_samples(dist[:, j], rmin=rmin, + rmax=rmax, neval=neval) + + cdf = numpy.asanyarray(cdf) + return rs, cdf + def __call__(self, *knns, nneighbours, Rmax, nsamples, rmin, rmax, neval, batch_size=None, verbose=True, random_state=42, left_nan=True, right_nan=True, dtype=numpy.float32):