PineTree

Continuation of Tom Charnock's neural bias project, the Neural Physical Engine (NPE). For the original publication, see https://arxiv.org/abs/1909.06379

It has since been extended in the recent publication https://arxiv.org/abs/2407.01391

This repository is the implementation of the physics-informed and generative neural network for halo bias modelling and halo mock production from (approximate) dark matter overdensity fields. The network architecture consists of two parts: a one-convolution network whose symmetric kernels are based on the multipole expansion to reduce the number of independent weights in the kernel, and a log-normal Gaussian mixture density network that emulates the conditional halo mass function. For more details and results obtained with this model, please see the references mentioned above.
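
As an illustration of the second component only, here is a minimal, hedged sketch (not the repository's actual implementation; all names are illustrative) of how raw network outputs can be turned into a log-normal Gaussian mixture density over halo mass using the JAX substrate of tensorflow-probability:

import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

def lognormal_mixture(params):
    # params has shape (..., 3 * n_components): mixture logits, means and raw scales
    logits, loc, raw_scale = jnp.split(params, 3, axis=-1)
    return tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(logits=logits),
        components_distribution=tfd.LogNormal(loc=loc, scale=jnp.exp(raw_scale)),
    )

# e.g. raw outputs of the convolutional part for a 4-component mixture
dist = lognormal_mixture(jnp.zeros(12))
log_density = dist.log_prob(2.5)  # conditional log-density at a halo mass value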

The main code is located in the folder npe/, while the folders scripts/ and notes/ contain a (so far) unsorted collection of scripts and Jupyter notebooks that can be used as examples of how to utilise the code. An example training script is provided at the end of this README.

For questions or bug reports, please open an issue or e-mail me at simon.ding@iap.fr

Installation

Create the conda environment (can be slow) by running

conda env create -f environment.yml

Then install the repository as a package via

conda activate npe
cd npe/
pip install -e .

This assumes that cd npe/ takes you to the root of the repository.

(Currently not working!) Test that everything runs by executing

pytest test

Install optional dependency bias-bench

For benchmarking and plotting routines, this repository relies on the bias-bench package. It should be installed into the same conda environment following the instructions here.

Troubleshooting

It can happen that the conda install messes up some dependencies. In this case, the environment can be set up manually using the following steps:

Create a new conda environment with Python 3 by running

conda create --name npe_env 'python>=3.9'

First install the core dependencies of this package using

conda install numpy pandas pyaml tqdm notebook 'h5py>=3.9.0' 'matplotlib>=3.6'

If you don't want to run JAX with GPU support, there is no need to install it separately, since it comes as a dependency of the flax package. For JAX with GPU support, please check the installation instructions at https://github.com/google/jax#installation. In that case it is best to install flax after having installed jax, and be aware that a working CUDA installation and a matching jaxlib are required.

A working installation order is to begin with the conda dependencies

conda install pandas matplotlib pyaml tqdm h5py

Then install jaxlib, jax and flax. First upgrade pip and install jaxlib via

pip install --upgrade pip
conda install jaxlib

and

# CUDA 12 installation
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11 installation
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

pip install flax

Finally, install tensorflow-probability for JAX via

pip install -Uq tfp-nightly[jax] > /dev/null

It is important to install the JAX variant of tensorflow-probability, since it is independent of TensorFlow itself. For more information, see https://www.tensorflow.org/probability/examples/TensorFlow_Probability_on_JAX.
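
As a quick, minimal sanity check (assuming the tfp-nightly[jax] install above), the JAX substrate can be imported and used without TensorFlow being installed:

import jax
from tensorflow_probability.substrates import jax as tfp

normal = tfp.distributions.Normal(loc=0.0, scale=1.0)
sample = normal.sample(seed=jax.random.PRNGKey(0))  # sampling requires an explicit PRNG key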

In order to run the unit tests please install pytest using

conda install pytest

Extra: Configure git for Jupyter notebooks

In order to avoid checking in all outputs from Jupyter notebooks, one can configure a git filter via

git config filter.strip-notebook-output.clean 'jupyter nbconvert --clear-output --inplace --stdin --stdout --log-level=ERROR'

The .gitattributes file will then apply this filter whenever a notebook is added via git. This operation does not affect the local notebook state.
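
For reference, the filter is assigned to notebooks in .gitattributes with an entry along these lines (the exact pattern used in this repository may differ):

*.ipynb filter=strip-notebook-output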

Example use case

A minimal training example is provided by scripts/train_pinetree.py, with the necessary files in the example/ folder. It can be run via

python scripts/train_pinetree.py --model_config example/model_config.yml --iteration_count 10