diff --git a/csiborgtools/match/knn.py b/csiborgtools/match/knn.py index 99b81f1..aac55c7 100644 --- a/csiborgtools/match/knn.py +++ b/csiborgtools/match/knn.py @@ -15,9 +15,9 @@ """ kNN-CDF calculation """ -from gc import collect import numpy from scipy.interpolate import interp1d +from scipy.stats import binned_statistic from tqdm import tqdm @@ -124,8 +124,58 @@ class kNN_CDF: cdf[cdf > 0.5] = 1 - cdf[cdf > 0.5] return cdf + def brute_cdf(self, knn, nneighbours, Rmax, nsamples, rmin, rmax, neval, + random_state=42, dtype=numpy.float32): + """ + Calculate the CDF for a kNN of CSiBORG halo catalogues without batch + sizing. This can become memory intense for large numbers of randoms + and, therefore, is only for testing purposes. + + Parameters + ---------- + knns : `sklearn.neighbors.NearestNeighbors` + kNN of CSiBORG halo catalogues. + neighbours : int + Maximum number of neighbours to use for the kNN-CDF calculation. + Rmax : float + Maximum radius of the sphere in which to sample random points for + the knn-CDF calculation. This should match the CSiBORG catalogues. + nsamples : int + Number of random points to sample for the knn-CDF calculation. + rmin : float + Minimum distance to evaluate the CDF. + rmax : float + Maximum distance to evaluate the CDF. + neval : int + Number of points to evaluate the CDF. + random_state : int, optional + Random state for the random number generator. + dtype : numpy dtype, optional + Calculation data type. By default `numpy.float32`. + + Returns + ------- + rs : 1-dimensional array + Distances at which the CDF is evaluated. + cdfs : 2-dimensional array + CDFs evaluated at `rs`. + """ + rand = self.rvs_in_sphere(nsamples, Rmax, random_state=random_state) + + dist, __ = knn.kneighbors(rand, nneighbours) + dist = dist.astype(dtype) + + cdf = [None] * nneighbours + for j in range(nneighbours): + rs, cdf[j] = self.cdf_from_samples(dist[:, j], rmin=rmin, + rmax=rmax, neval=neval) + + cdf = numpy.asanyarray(cdf) + return rs, cdf + def __call__(self, *knns, nneighbours, Rmax, nsamples, rmin, rmax, neval, - verbose=True, random_state=42, dtype=numpy.float32): + batch_size=None, verbose=True, random_state=42, + left_nan=True, right_nan=True, dtype=numpy.float32): """ Calculate the CDF for a set of kNNs of CSiBORG halo catalogues. @@ -146,10 +196,20 @@ class kNN_CDF: Maximum distance to evaluate the CDF. neval : int Number of points to evaluate the CDF. + batch_size : int, optional + Number of random points to sample in each batch. By default equal + to `nsamples`, however recommeded to be smaller to avoid requesting + too much memory, verbose : bool, optional Verbosity flag. random_state : int, optional Random state for the random number generator. + left_nan : bool, optional + Whether to set values where the CDF is 0 to `numpy.nan`. By + default `True`. + right_nan : bool, optional + Whether to set values where the CDF is 1 to `numpy.nan` after its + first occurence to 1. By default `True`. dtype : numpy dtype, optional Calculation data type. By default `numpy.float32`. @@ -160,22 +220,40 @@ class kNN_CDF: cdfs : 2 or 3-dimensional array CDFs evaluated at `rs`. """ - rand = self.rvs_in_sphere(nsamples, Rmax, random_state=random_state) + batch_size = nsamples if batch_size is None else batch_size + assert nsamples >= batch_size + nbatches = nsamples // batch_size # Number of batches - cdfs = [None] * len(knns) + # Preallocate the bins and the CDF array + bins = numpy.logspace(numpy.log10(rmin), numpy.log10(rmax), neval) + cdfs = numpy.zeros((len(knns), nneighbours, neval - 1), dtype=dtype) for i, knn in enumerate(tqdm(knns) if verbose else knns): - dist, _indxs = knn.kneighbors(rand, nneighbours) - dist = dist.astype(dtype) - del _indxs - collect() + # Loop over batches. This is to avoid generating large mocks + # requiring a lot of memory. Add counts to the CDF array + for j in range(nbatches): + rand = self.rvs_in_sphere(batch_size, Rmax, + random_state=random_state + j) + dist, __ = knn.kneighbors(rand, nneighbours) + for k in range(nneighbours): # Count for each neighbour + _counts, __, __ = binned_statistic( + dist[:, k], dist[:, k], bins=bins, statistic="count", + range=(rmin, rmax)) + cdfs[i, k, :] += _counts + rs = (bins[1:] + bins[:-1]) / 2 # Bin centers + cdfs = numpy.cumsum(cdfs, axis=-1) # Cumulative sum, i.e. the CDF + for i in range(len(knns)): + for k in range(nneighbours): + cdfs[i, k, :] /= cdfs[i, k, -1] + # Set to NaN values after the first point where the CDF is 1 + if right_nan: + ns = numpy.where(cdfs[i, k, :] == 1.)[0] + if ns.size > 1: + cdfs[i, k, ns[1]:] = numpy.nan - cdf = [None] * nneighbours - for j in range(nneighbours): - rs, cdf[j] = self.cdf_from_samples( - dist[:, j], rmin=rmin, rmax=rmax, neval=neval) - cdfs[i] = cdf + # Set to NaN values where the CDF is 0 + if left_nan: + cdfs[cdfs == 0] = numpy.nan - cdfs = numpy.asanyarray(cdfs) cdfs = cdfs[0, ...] if len(knns) == 1 else cdfs return rs, cdfs diff --git a/notebooks/knn.ipynb b/notebooks/knn.ipynb index 8ebdb6d..59a005c 100644 --- a/notebooks/knn.ipynb +++ b/notebooks/knn.ipynb @@ -2,12 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "5a38ed25", "metadata": { "ExecuteTime": { - "end_time": "2023-03-31T17:09:12.165480Z", - "start_time": "2023-03-31T17:09:12.116708Z" + "end_time": "2023-04-01T06:20:33.195162Z", + "start_time": "2023-04-01T06:20:29.474122Z" }, "scrolled": true }, @@ -16,8 +16,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" + "not found\n" ] } ], @@ -44,12 +43,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "4218b673", "metadata": { "ExecuteTime": { - "end_time": "2023-03-31T17:09:13.943312Z", - "start_time": "2023-03-31T17:09:12.167027Z" + "end_time": "2023-04-01T06:20:35.273662Z", + "start_time": "2023-04-01T06:20:33.196875Z" } }, "outputs": [], @@ -59,12 +58,12 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 24, "id": "5ff7a1b6", "metadata": { "ExecuteTime": { - "end_time": "2023-03-31T17:10:18.303240Z", - "start_time": "2023-03-31T17:10:14.674751Z" + "end_time": "2023-04-01T06:55:34.643955Z", + "start_time": "2023-04-01T06:55:28.334204Z" } }, "outputs": [ @@ -72,38 +71,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\r", - " 0%| | 0/1 [00:00