mirror of
https://github.com/Richard-Sti/csiborgtools.git
synced 2024-12-22 17:48:01 +00:00
kNN memory batching (#35)
* Add batch sizing for less memory * Add batch size to submission * Update nb * Add brute KNN * unused variable * Update nb
This commit is contained in:
parent
63ab3548b4
commit
513872ceb6
4 changed files with 188 additions and 64 deletions
|
@ -15,9 +15,9 @@
|
||||||
"""
|
"""
|
||||||
kNN-CDF calculation
|
kNN-CDF calculation
|
||||||
"""
|
"""
|
||||||
from gc import collect
|
|
||||||
import numpy
|
import numpy
|
||||||
from scipy.interpolate import interp1d
|
from scipy.interpolate import interp1d
|
||||||
|
from scipy.stats import binned_statistic
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,8 +124,58 @@ class kNN_CDF:
|
||||||
cdf[cdf > 0.5] = 1 - cdf[cdf > 0.5]
|
cdf[cdf > 0.5] = 1 - cdf[cdf > 0.5]
|
||||||
return cdf
|
return cdf
|
||||||
|
|
||||||
|
def brute_cdf(self, knn, nneighbours, Rmax, nsamples, rmin, rmax, neval,
|
||||||
|
random_state=42, dtype=numpy.float32):
|
||||||
|
"""
|
||||||
|
Calculate the CDF for a kNN of CSiBORG halo catalogues without batch
|
||||||
|
sizing. This can become memory intense for large numbers of randoms
|
||||||
|
and, therefore, is only for testing purposes.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
knns : `sklearn.neighbors.NearestNeighbors`
|
||||||
|
kNN of CSiBORG halo catalogues.
|
||||||
|
neighbours : int
|
||||||
|
Maximum number of neighbours to use for the kNN-CDF calculation.
|
||||||
|
Rmax : float
|
||||||
|
Maximum radius of the sphere in which to sample random points for
|
||||||
|
the knn-CDF calculation. This should match the CSiBORG catalogues.
|
||||||
|
nsamples : int
|
||||||
|
Number of random points to sample for the knn-CDF calculation.
|
||||||
|
rmin : float
|
||||||
|
Minimum distance to evaluate the CDF.
|
||||||
|
rmax : float
|
||||||
|
Maximum distance to evaluate the CDF.
|
||||||
|
neval : int
|
||||||
|
Number of points to evaluate the CDF.
|
||||||
|
random_state : int, optional
|
||||||
|
Random state for the random number generator.
|
||||||
|
dtype : numpy dtype, optional
|
||||||
|
Calculation data type. By default `numpy.float32`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
rs : 1-dimensional array
|
||||||
|
Distances at which the CDF is evaluated.
|
||||||
|
cdfs : 2-dimensional array
|
||||||
|
CDFs evaluated at `rs`.
|
||||||
|
"""
|
||||||
|
rand = self.rvs_in_sphere(nsamples, Rmax, random_state=random_state)
|
||||||
|
|
||||||
|
dist, __ = knn.kneighbors(rand, nneighbours)
|
||||||
|
dist = dist.astype(dtype)
|
||||||
|
|
||||||
|
cdf = [None] * nneighbours
|
||||||
|
for j in range(nneighbours):
|
||||||
|
rs, cdf[j] = self.cdf_from_samples(dist[:, j], rmin=rmin,
|
||||||
|
rmax=rmax, neval=neval)
|
||||||
|
|
||||||
|
cdf = numpy.asanyarray(cdf)
|
||||||
|
return rs, cdf
|
||||||
|
|
||||||
def __call__(self, *knns, nneighbours, Rmax, nsamples, rmin, rmax, neval,
|
def __call__(self, *knns, nneighbours, Rmax, nsamples, rmin, rmax, neval,
|
||||||
verbose=True, random_state=42, dtype=numpy.float32):
|
batch_size=None, verbose=True, random_state=42,
|
||||||
|
left_nan=True, right_nan=True, dtype=numpy.float32):
|
||||||
"""
|
"""
|
||||||
Calculate the CDF for a set of kNNs of CSiBORG halo catalogues.
|
Calculate the CDF for a set of kNNs of CSiBORG halo catalogues.
|
||||||
|
|
||||||
|
@ -146,10 +196,20 @@ class kNN_CDF:
|
||||||
Maximum distance to evaluate the CDF.
|
Maximum distance to evaluate the CDF.
|
||||||
neval : int
|
neval : int
|
||||||
Number of points to evaluate the CDF.
|
Number of points to evaluate the CDF.
|
||||||
|
batch_size : int, optional
|
||||||
|
Number of random points to sample in each batch. By default equal
|
||||||
|
to `nsamples`, however recommeded to be smaller to avoid requesting
|
||||||
|
too much memory,
|
||||||
verbose : bool, optional
|
verbose : bool, optional
|
||||||
Verbosity flag.
|
Verbosity flag.
|
||||||
random_state : int, optional
|
random_state : int, optional
|
||||||
Random state for the random number generator.
|
Random state for the random number generator.
|
||||||
|
left_nan : bool, optional
|
||||||
|
Whether to set values where the CDF is 0 to `numpy.nan`. By
|
||||||
|
default `True`.
|
||||||
|
right_nan : bool, optional
|
||||||
|
Whether to set values where the CDF is 1 to `numpy.nan` after its
|
||||||
|
first occurence to 1. By default `True`.
|
||||||
dtype : numpy dtype, optional
|
dtype : numpy dtype, optional
|
||||||
Calculation data type. By default `numpy.float32`.
|
Calculation data type. By default `numpy.float32`.
|
||||||
|
|
||||||
|
@ -160,22 +220,40 @@ class kNN_CDF:
|
||||||
cdfs : 2 or 3-dimensional array
|
cdfs : 2 or 3-dimensional array
|
||||||
CDFs evaluated at `rs`.
|
CDFs evaluated at `rs`.
|
||||||
"""
|
"""
|
||||||
rand = self.rvs_in_sphere(nsamples, Rmax, random_state=random_state)
|
batch_size = nsamples if batch_size is None else batch_size
|
||||||
|
assert nsamples >= batch_size
|
||||||
|
nbatches = nsamples // batch_size # Number of batches
|
||||||
|
|
||||||
cdfs = [None] * len(knns)
|
# Preallocate the bins and the CDF array
|
||||||
|
bins = numpy.logspace(numpy.log10(rmin), numpy.log10(rmax), neval)
|
||||||
|
cdfs = numpy.zeros((len(knns), nneighbours, neval - 1), dtype=dtype)
|
||||||
for i, knn in enumerate(tqdm(knns) if verbose else knns):
|
for i, knn in enumerate(tqdm(knns) if verbose else knns):
|
||||||
dist, _indxs = knn.kneighbors(rand, nneighbours)
|
# Loop over batches. This is to avoid generating large mocks
|
||||||
dist = dist.astype(dtype)
|
# requiring a lot of memory. Add counts to the CDF array
|
||||||
del _indxs
|
for j in range(nbatches):
|
||||||
collect()
|
rand = self.rvs_in_sphere(batch_size, Rmax,
|
||||||
|
random_state=random_state + j)
|
||||||
|
dist, __ = knn.kneighbors(rand, nneighbours)
|
||||||
|
for k in range(nneighbours): # Count for each neighbour
|
||||||
|
_counts, __, __ = binned_statistic(
|
||||||
|
dist[:, k], dist[:, k], bins=bins, statistic="count",
|
||||||
|
range=(rmin, rmax))
|
||||||
|
cdfs[i, k, :] += _counts
|
||||||
|
|
||||||
|
rs = (bins[1:] + bins[:-1]) / 2 # Bin centers
|
||||||
|
cdfs = numpy.cumsum(cdfs, axis=-1) # Cumulative sum, i.e. the CDF
|
||||||
|
for i in range(len(knns)):
|
||||||
|
for k in range(nneighbours):
|
||||||
|
cdfs[i, k, :] /= cdfs[i, k, -1]
|
||||||
|
# Set to NaN values after the first point where the CDF is 1
|
||||||
|
if right_nan:
|
||||||
|
ns = numpy.where(cdfs[i, k, :] == 1.)[0]
|
||||||
|
if ns.size > 1:
|
||||||
|
cdfs[i, k, ns[1]:] = numpy.nan
|
||||||
|
|
||||||
cdf = [None] * nneighbours
|
# Set to NaN values where the CDF is 0
|
||||||
for j in range(nneighbours):
|
if left_nan:
|
||||||
rs, cdf[j] = self.cdf_from_samples(
|
cdfs[cdfs == 0] = numpy.nan
|
||||||
dist[:, j], rmin=rmin, rmax=rmax, neval=neval)
|
|
||||||
cdfs[i] = cdf
|
|
||||||
|
|
||||||
cdfs = numpy.asanyarray(cdfs)
|
|
||||||
cdfs = cdfs[0, ...] if len(knns) == 1 else cdfs
|
cdfs = cdfs[0, ...] if len(knns) == 1 else cdfs
|
||||||
return rs, cdfs
|
return rs, cdfs
|
||||||
|
|
|
@ -2,12 +2,12 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 1,
|
||||||
"id": "5a38ed25",
|
"id": "5a38ed25",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2023-03-31T17:09:12.165480Z",
|
"end_time": "2023-04-01T06:20:33.195162Z",
|
||||||
"start_time": "2023-03-31T17:09:12.116708Z"
|
"start_time": "2023-04-01T06:20:29.474122Z"
|
||||||
},
|
},
|
||||||
"scrolled": true
|
"scrolled": true
|
||||||
},
|
},
|
||||||
|
@ -16,8 +16,7 @@
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"The autoreload extension is already loaded. To reload it, use:\n",
|
"not found\n"
|
||||||
" %reload_ext autoreload\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -44,12 +43,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 2,
|
||||||
"id": "4218b673",
|
"id": "4218b673",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2023-03-31T17:09:13.943312Z",
|
"end_time": "2023-04-01T06:20:35.273662Z",
|
||||||
"start_time": "2023-03-31T17:09:12.167027Z"
|
"start_time": "2023-04-01T06:20:33.196875Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
@ -59,12 +58,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 24,
|
||||||
"id": "5ff7a1b6",
|
"id": "5ff7a1b6",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"ExecuteTime": {
|
"ExecuteTime": {
|
||||||
"end_time": "2023-03-31T17:10:18.303240Z",
|
"end_time": "2023-04-01T06:55:34.643955Z",
|
||||||
"start_time": "2023-03-31T17:10:14.674751Z"
|
"start_time": "2023-04-01T06:55:28.334204Z"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -72,38 +71,7 @@
|
||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\r",
|
"100%|██████████| 1/1 [00:02<00:00, 2.95s/it]\n"
|
||||||
" 0%| | 0/1 [00:00<?, ?it/s]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"float32\n",
|
|
||||||
"float32\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"100%|██████████| 1/1 [00:03<00:00, 3.37s/it]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"float32\n",
|
|
||||||
"float32\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -113,18 +81,90 @@
|
||||||
"\n",
|
"\n",
|
||||||
"knncdf = csiborgtools.match.kNN_CDF()\n",
|
"knncdf = csiborgtools.match.kNN_CDF()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"rs, cdfs_high = knncdf(knn, nneighbours=3, Rmax=155 / 0.705, rmin=0.05, rmax=40,\n",
|
"rs, cdf = knncdf(knn, nneighbours=2, Rmax=155 / 0.705, rmin=0.01, rmax=100,\n",
|
||||||
" nsamples=int(1e6), neval=int(1e4), random_state=42)"
|
" nsamples=int(1e6), neval=int(1e4), random_state=42, batch_size=int(1e6))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "08321431",
|
"id": "0d5f3d02",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "8b9a8cf0",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "a1825f00",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-04-01T06:01:29.388586Z",
|
||||||
|
"start_time": "2023-04-01T06:01:29.321025Z"
|
||||||
|
},
|
||||||
|
"scrolled": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plt.figure()\n",
|
||||||
|
"plt.plot(rs, knncdf.peaked_cdf(cdf[0, :]))\n",
|
||||||
|
"\n",
|
||||||
|
"plt.yscale(\"log\" )\n",
|
||||||
|
"plt.xscale(\"log\")\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "289549a0",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-03-31T22:55:20.690887Z",
|
||||||
|
"start_time": "2023-03-31T22:55:20.656550Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"mask"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7a8c5202",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-03-31T22:54:52.330633Z",
|
||||||
|
"start_time": "2023-03-31T22:54:52.299548Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "46f54897",
|
||||||
|
"metadata": {
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2023-03-31T22:54:25.138813Z",
|
||||||
|
"start_time": "2023-03-31T22:54:25.105044Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"dist"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|
|
@ -42,6 +42,7 @@ parser.add_argument("--rmax", type=float)
|
||||||
parser.add_argument("--nneighbours", type=int)
|
parser.add_argument("--nneighbours", type=int)
|
||||||
parser.add_argument("--nsamples", type=int)
|
parser.add_argument("--nsamples", type=int)
|
||||||
parser.add_argument("--neval", type=int)
|
parser.add_argument("--neval", type=int)
|
||||||
|
parser.add_argument("--batch_size", type=int)
|
||||||
parser.add_argument("--seed", type=int, default=42)
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
@ -77,8 +78,8 @@ def do_task(ic):
|
||||||
|
|
||||||
rs, cdf = knncdf(knn, nneighbours=args.nneighbours, Rmax=Rmax,
|
rs, cdf = knncdf(knn, nneighbours=args.nneighbours, Rmax=Rmax,
|
||||||
rmin=args.rmin, rmax=args.rmax, nsamples=args.nsamples,
|
rmin=args.rmin, rmax=args.rmax, nsamples=args.nsamples,
|
||||||
neval=args.neval, random_state=args.seed,
|
neval=args.neval, batch_size=args.batch_size,
|
||||||
verbose=False)
|
random_state=args.seed, verbose=False)
|
||||||
out.update({"cdf_{}".format(i): cdf})
|
out.update({"cdf_{}".format(i): cdf})
|
||||||
|
|
||||||
out.update({"rs": rs, "mass_threshold": mass_threshold})
|
out.update({"rs": rs, "mass_threshold": mass_threshold})
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
nthreads=140
|
nthreads=30
|
||||||
memory=7
|
memory=7
|
||||||
queue="berg"
|
queue="berg"
|
||||||
env="/mnt/zfsusers/rstiskalek/csiborgtools/venv_galomatch/bin/python"
|
env="/mnt/zfsusers/rstiskalek/csiborgtools/venv_galomatch/bin/python"
|
||||||
|
@ -7,9 +7,14 @@ file="run_knn.py"
|
||||||
rmin=0.01
|
rmin=0.01
|
||||||
rmax=100
|
rmax=100
|
||||||
nneighbours=16
|
nneighbours=16
|
||||||
nsamples=10000000
|
nsamples=1000000000
|
||||||
|
batch_size=10000000
|
||||||
neval=10000
|
neval=10000
|
||||||
|
|
||||||
|
# 1000,000,0
|
||||||
|
# 10000000 # 1e7
|
||||||
|
# 1000000000
|
||||||
|
|
||||||
pythoncm="$env $file --rmin $rmin --rmax $rmax --nneighbours $nneighbours --nsamples $nsamples --neval $neval"
|
pythoncm="$env $file --rmin $rmin --rmax $rmax --nneighbours $nneighbours --nsamples $nsamples --neval $neval"
|
||||||
|
|
||||||
# echo $pythoncm
|
# echo $pythoncm
|
||||||
|
|
Loading…
Reference in a new issue