mirror of
https://github.com/Richard-Sti/csiborgtools_public.git
synced 2025-05-12 05:38:42 +00:00
Joint kNN-CDF calculation (#36)
* Add joint kNN CDF * add jointKNN calculation * change sub script * Update readme * update sub * Small changes * comments * update nb * Update submission script
This commit is contained in:
parent
cb67e326c4
commit
522ee709c9
5 changed files with 4569 additions and 57 deletions
|
@ -17,6 +17,7 @@ from os.path import join
|
|||
from argparse import ArgumentParser
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from itertools import combinations
|
||||
from mpi4py import MPI
|
||||
from TaskmasterMPI import master_process, worker_process
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
|
@ -59,7 +60,8 @@ ics = [7444, 7468, 7492, 7516, 7540, 7564, 7588, 7612, 7636, 7660, 7684,
|
|||
9556, 9580, 9604, 9628, 9652, 9676, 9700, 9724, 9748, 9772, 9796,
|
||||
9820, 9844]
|
||||
dumpdir = "/mnt/extraspace/rstiskalek/csiborg/knn"
|
||||
fout = join(dumpdir, "knncdf_{}.p")
|
||||
fout_auto = join(dumpdir, "auto", "knncdf_{}.p")
|
||||
fout_cross = join(dumpdir, "cross", "knncdf_{}_{}.p")
|
||||
|
||||
|
||||
###############################################################################
|
||||
|
@ -68,7 +70,7 @@ fout = join(dumpdir, "knncdf_{}.p")
|
|||
knncdf = csiborgtools.match.kNN_CDF()
|
||||
|
||||
|
||||
def do_task(ic):
|
||||
def do_auto(ic):
|
||||
out = {}
|
||||
cat = csiborgtools.read.HaloCatalogue(ic, max_dist=Rmax)
|
||||
|
||||
|
@ -83,7 +85,39 @@ def do_task(ic):
|
|||
out.update({"cdf_{}".format(i): cdf})
|
||||
|
||||
out.update({"rs": rs, "mass_threshold": mass_threshold})
|
||||
joblib.dump(out, fout.format(ic))
|
||||
joblib.dump(out, fout_auto.format(ic))
|
||||
|
||||
|
||||
def do_cross(ics):
    """
    Evaluate the joint kNN-CDF between two catalogues and dump it to disk.

    For every entry of the module-level `mass_threshold`, a
    `NearestNeighbors` model is fitted on the positions of each catalogue
    whose `totpartmass` exceeds that threshold. The two models are passed to
    `knncdf.joint` and its output is converted with `knncdf.joint_to_corr`.
    The results are written with `joblib.dump` to
    `fout_cross.format(*ics)`.

    Parameters
    ----------
    ics : sequence of two items
        Identifiers forwarded to `csiborgtools.read.HaloCatalogue`; only
        `ics[0]` and `ics[1]` are read when loading, while all items are
        used to format the output path.

    Returns
    -------
    None
    """
    cat_a = csiborgtools.read.HaloCatalogue(ics[0], max_dist=Rmax)
    cat_b = csiborgtools.read.HaloCatalogue(ics[1], max_dist=Rmax)

    result = {}
    for idx, threshold in enumerate(mass_threshold):
        # Fit one neighbour model per catalogue, first catalogue first.
        fitted = []
        for cat in (cat_a, cat_b):
            model = NearestNeighbors()
            model.fit(cat.positions[cat["totpartmass"] > threshold, ...])
            fitted.append(model)
        knn_a, knn_b = fitted

        joint_kwargs = dict(
            nneighbours=args.nneighbours, Rmax=Rmax, rmin=args.rmin,
            rmax=args.rmax, nsamples=args.nsamples, neval=args.neval,
            batch_size=args.batch_size, random_state=args.seed)
        rs, cdf0, cdf1, joint_cdf = knncdf.joint(knn_a, knn_b, **joint_kwargs)

        result["corr_{}".format(idx)] = knncdf.joint_to_corr(
            cdf0, cdf1, joint_cdf)

    # `rs` is taken from the last threshold iteration, as in the auto case.
    result["rs"] = rs
    result["mass_threshold"] = mass_threshold
    joblib.dump(result, fout_cross.format(*ics))
|
||||
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Autocorrelation calculation #
|
||||
###############################################################################
|
||||
|
||||
|
||||
if nproc > 1:
|
||||
|
@ -91,15 +125,34 @@ if nproc > 1:
|
|||
tasks = deepcopy(ics)
|
||||
master_process(tasks, comm, verbose=True)
|
||||
else:
|
||||
worker_process(do_task, comm, verbose=False)
|
||||
worker_process(do_auto, comm, verbose=False)
|
||||
else:
|
||||
tasks = deepcopy(ics)
|
||||
for task in tasks:
|
||||
print("{}: completing task `{}`.".format(datetime.now(), task))
|
||||
do_task(task)
|
||||
|
||||
|
||||
do_auto(task)
|
||||
comm.Barrier()
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Crosscorrelation calculation #
|
||||
###############################################################################
|
||||
|
||||
|
||||
if nproc > 1:
    if rank == 0:
        # Rank 0 only distributes the unordered pairs of IC indices; the
        # remaining ranks evaluate them via `do_cross`.
        tasks = list(combinations(ics, 2))
        master_process(tasks, comm, verbose=True)
    else:
        worker_process(do_cross, comm, verbose=False)
else:
    # `do_cross` indexes its argument as a pair (`ics[0]`, `ics[1]`), so the
    # serial path must iterate over the same pairs as the MPI path rather
    # than over the bare IC indices.
    tasks = list(combinations(ics, 2))
    for task in tasks:
        print("{}: completing task `{}`.".format(datetime.now(), task))
        do_cross(task)
# Wait for every rank to finish before moving on.
comm.Barrier()
|
||||
|
||||
|
||||
if rank == 0:
|
||||
print("{}: all finished.".format(datetime.now()))
|
||||
quit() # Force quit the script
|
||||
quit() # Force quit the script
|
||||
|
|
|
@ -1,21 +1,17 @@
|
|||
nthreads=30
|
||||
memory=7
|
||||
queue="berg"
|
||||
nthreads=151
|
||||
memory=4
|
||||
queue="cmb"
|
||||
env="/mnt/zfsusers/rstiskalek/csiborgtools/venv_galomatch/bin/python"
|
||||
file="run_knn.py"
|
||||
|
||||
rmin=0.01
|
||||
rmax=100
|
||||
nneighbours=16
|
||||
nsamples=1000000000
|
||||
batch_size=10000000
|
||||
nneighbours=8
|
||||
nsamples=100000000
|
||||
batch_size=1000000
|
||||
neval=10000
|
||||
|
||||
# 1000,000,0
|
||||
# 10000000 # 1e7
|
||||
# 1000000000
|
||||
|
||||
pythoncm="$env $file --rmin $rmin --rmax $rmax --nneighbours $nneighbours --nsamples $nsamples --neval $neval"
|
||||
pythoncm="$env $file --rmin $rmin --rmax $rmax --nneighbours $nneighbours --nsamples $nsamples --batch_size $batch_size --neval $neval"
|
||||
|
||||
# echo $pythoncm
|
||||
# $pythoncm
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue