JaxPM/notebooks/02-Advanced_usage.ipynb
Wassim KABALAN 695e359f56 jaxdecomp proto (#21)
* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operations in Fourier space now take inverted sharding axis for slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge Hugo's LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
2024-12-20 05:44:02 -05:00

5.1 MiB

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from h5py import File
from scipy.stats import spearmanr

import csiborgtools

%matplotlib inline
%load_ext autoreload
%autoreload 2
In [2]:
paths = csiborgtools.read.Paths(**csiborgtools.paths_glamdring)

# d = np.load(paths.field_interpolated("SDSS", "csiborg2_main", 16817, "density", "SPH", 1024))
In [33]:
survey = csiborgtools.SDSS()(apply_selection=False)
# survey = csiborgtools.SDSSxALFALFA()(apply_selection=False)
In [35]:
for kind in ["main", "random"]:
    x, smooth = csiborgtools.summary.read_interpolated_field(survey, f"csiborg2_{kind}", "density", "SPH", 1024, paths)
    np.savez(f"../data/{survey.name}_{kind}_density_SPH_1024.npz", val=x, smooth_scales=smooth)
Reading fields: 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]
Reading fields: 100%|██████████| 20/20 [00:10<00:00,  1.86it/s]
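The loop above writes each result to an `.npz` archive under the keys `val` and `smooth_scales`. The following is a minimal sketch of how such an archive round-trips through `np.savez` and `np.load`. It uses small synthetic arrays and a temporary file; the shapes, values, and file name are hypothetical stand-ins, not the real survey data.

```python
import os
import tempfile

import numpy as np

# Synthetic stand-ins (hypothetical shapes and values, not the real fields).
val = np.arange(2 * 3 * 5, dtype=np.float32).reshape(2, 3, 5)
smooth_scales = np.array([0.0, 2.0, 4.0, 8.0, 16.0])

with tempfile.TemporaryDirectory() as tmpdir:
    fname = os.path.join(tmpdir, "example_density.npz")

    # Keyword names become the keys of the archive.
    np.savez(fname, val=val, smooth_scales=smooth_scales)

    # np.load returns a lazy archive; index it by key to read an array.
    with np.load(fname) as archive:
        keys = sorted(archive.files)
        val_loaded = archive["val"]
        smooth_loaded = archive["smooth_scales"]

print(keys, val_loaded.shape, smooth_loaded.shape)
```

Opening the archive in a `with` block closes the underlying file once the arrays have been read into memory.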
In [37]:

Out[37]:
(20, 641409, 5)
In [24]:
np.load("../data/SDSS_main_density_SPH_1024.npz")["val"]
Out[24]:
array([[[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]],

       [[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]],

       [[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]],

       ...,

       [[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]],

       [[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]],

       [[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]]], dtype=float32)
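With NumPy's default print options, the summarised printout above shows only the first and last three entries along each axis, so it cannot tell whether the whole array is NaN or only its edges. A minimal sketch of a quick check follows, run on a synthetic array. The helper name `nan_summary` is hypothetical, and treating axis 0 as the field index is an assumption about the saved layout.

```python
import numpy as np


def nan_summary(arr):
    """Return the overall NaN fraction and the NaN fraction of each
    entry along axis 0 (assumed here to index the fields)."""
    mask = np.isnan(arr)
    # Flatten everything except axis 0, then average the boolean mask.
    per_first_axis = mask.reshape(mask.shape[0], -1).mean(axis=1)
    return mask.mean(), per_first_axis


# Synthetic example: 4 "fields", 10 points, 5 values per point.
example = np.ones((4, 10, 5), dtype=np.float32)
example[0] = np.nan        # first field entirely NaN
example[1, :5] = np.nan    # half of the second field NaN

total_frac, per_field_frac = nan_summary(example)
print(total_frac, per_field_frac)
```

Applied to the loaded `val` array, a per-field fraction of exactly 1 would mean that field contains no finite values, while a smaller fraction would point to NaNs confined to some of the points.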