JaxPM/jaxpm/plotting.py
Wassim KABALAN df8602b318 jaxdecomp proto (#21)
* adding example of distributed solution

* put back old function

* update formatting

* add halo exchange and slice pad

* apply formatting

* implement distributed optimized cic_paint

* Use new cic_paint with halo

* Fix seed for distributed normal

* Wrap interpolation function to avoid all gather

* Return normal order frequencies for single GPU

* add example

* format

* add optimised bench script

* times in ms

* add lpt2

* update benchmark and add slurm

* Visualize only final field

* Update scripts/distributed_pm.py

Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>

* Adjust pencil type for frequencies

* fix painting issue with slabs

* Shared operations in Fourier space now take the inverted sharding axis for
slabs

* add assert to make pyright happy

* adjust test for hpc-plotter

* add PMWD test

* bench

* format

* added github workflow

* fix formatting from main

* Update for jaxDecomp pure JAX

* revert single halo extent change

* update for latest jaxDecomp

* remove fourrier_space in autoshmap

* make normal_field work with single controller

* format

* make distributed pm work in single controller

* merge bench_pm

* update to leapfrog

* add a strict dependency on jaxdecomp

* global mesh no longer needed

* kernels.py no longer uses global mesh

* quick fix in distributed

* pm.py no longer uses global mesh

* painting.py no longer uses global mesh

* update demo script

* quick fix in kernels

* quick fix in distributed

* update demo

* merge Hugo's LPT2 code

* format

* Small fix

* format

* remove duplicate get_ode_fn

* update visualizer

* update compensate CIC

* By default check_rep is false for shard_map

* remove experimental distributed code

* update PGDCorrection and neural ode to use new fft3d

* jaxDecomp pfft3d promotes to complex automatically

* remove deprecated stuff

* fix painting issue with read_cic

* use jnp interp instead of jc interp

* delete old slurms

* add notebook examples

* apply formatting

* add distributed zeros

* fix code in LPT2

* jit cic_paint

* update notebooks

* apply formatting

* get local shape and zeros can be used by users

* add a user facing function to create uniform particle grid

* use jax interp instead of jax_cosmo

* use float64 for enmeshing

* Allow applying weights with relative cic paint

* Weights can be traced

* remove script folder

* update example notebooks

* delete outdated design file

* add readme for tutorials

* update readme

* fix small error

* forgot particles in multi host

* clarifying why cic_paint_dx is slower

* clarifying the halo size dependence on the box size

* ability to choose snapshots number with MultiHost script

* Adding animation notebook

* Put plotting in package

* Add finite difference laplace kernel + powerspec functions from Hugo

Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>

* Put plotting utils in package

* By default use absolute painting with

* update code

* update notebooks

* add tests

* Upgrade setup.py to pyproject

* Format

* format tests

* update test dependencies

* add test workflow

* fix deprecated FftType in jaxpm.kernels

* Add aboucaud comments

* JAX version is 0.4.35 until Diffrax new release

* add numpy explicitly as dependency for tests

* fix install order for tests

* add numpy to be installed

* enforce no build isolation for fastpm

* pip install jaxpm test without build isolation

* bump jaxdecomp version

* revert test workflow

* remove outdated tests

---------

Co-authored-by: EiffL <fr.eiffel@gmail.com>
Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com>
Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr>
Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com>
Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
2024-12-20 05:44:02 -05:00

129 lines
4.6 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
def plot_fields(fields_dict, sum_over=None):
    """
    Plots sum projections of 3D fields along each of the three axes.

    One row of three panels is drawn per field. Panel ``k`` of a row shows the
    sum over the first ``sum_over`` slices along axis ``k`` of that field.

    Args:
    - fields_dict: dictionary mapping a field name (used in panel titles) to a
      3D array
    - sum_over: number of leading slices to sum along each projection axis
      (default, also used for any falsy value: ``shape[0] // 8`` of the first
      field)

    Raises:
    - ValueError: if ``fields_dict`` is empty
    """
    if not fields_dict:
        raise ValueError("fields_dict must contain at least one field")
    sum_over = sum_over or next(iter(fields_dict.values())).shape[0] // 8
    nb_rows = len(fields_dict)
    nb_cols = 3
    # squeeze=False keeps `axes` 2D even for a single field, so that
    # `axes[row, col]` indexing is always valid.
    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(15, 5 * nb_rows),
                             squeeze=False)

    def plot_subplots(proj_axis, field, row, title):
        """Draw the projection of `field` along `proj_axis` in axes[row, proj_axis]."""
        ax = axes[row, proj_axis]
        slicing = [slice(None)] * field.ndim
        slicing[proj_axis] = slice(None, sum_over)
        slicing = tuple(slicing)
        # After summing over `proj_axis`, image rows come from the lower
        # remaining axis and columns from the higher one. imshow's extent is
        # (left, right, bottom, top), so the column size comes first.
        row_axis, col_axis = [a for a in range(field.ndim) if a != proj_axis]
        # NOTE(review): the +1 offset presumably maps a density contrast
        # delta to 1 + delta; confirm.
        ax.imshow(field[slicing].sum(axis=proj_axis) + 1,
                  cmap='magma',
                  extent=[0, field.shape[col_axis], 0, field.shape[row_axis]])
        ax.set_xlabel('Mpc/h')
        ax.set_ylabel('Mpc/h')
        ax.set_title(title)

    # Plot each field across the three axes
    for i, (name, field) in enumerate(fields_dict.items()):
        for proj_axis in range(3):
            plot_subplots(proj_axis, field, i,
                          f"{name} projection {proj_axis}")
    plt.tight_layout()
    plt.show()
def plot_fields_single_projection(fields_dict,
                                  sum_over=None,
                                  project_axis=0,
                                  vmin=None,
                                  vmax=None,
                                  colorbar=False):
    """
    Plots one sum projection per 3D field, laid out in a grid of 4 images per row.

    Each panel shows the sum over the first ``sum_over`` slices along
    ``project_axis``. Grid cells left over when the number of fields is not a
    multiple of 4 are removed.

    Args:
    - fields_dict: dictionary mapping field names (used in titles) to 3D arrays
    - sum_over: number of leading slices to sum along the projection axis
      (default, also used for any falsy value: ``shape[0] // 8`` of the first
      field)
    - project_axis: axis to project along (default: 0); negative values are
      accepted
    - vmin, vmax: color limits forwarded to ``imshow`` (default: autoscale)
    - colorbar: if True, draw a colorbar next to each panel

    Raises:
    - ValueError: if ``fields_dict`` is empty
    """
    if not fields_dict:
        raise ValueError("fields_dict must contain at least one field")
    sum_over = sum_over or next(iter(fields_dict.values())).shape[0] // 8
    nb_fields = len(fields_dict)
    nb_cols = 4  # Set number of images per row
    nb_rows = (nb_fields + nb_cols - 1) // nb_cols  # Ceiling division
    # squeeze=False keeps `axes` 2D whatever nb_rows is.
    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(5 * nb_cols, 5 * nb_rows),
                             squeeze=False)
    for i, (name, field) in enumerate(fields_dict.items()):
        row, col = divmod(i, nb_cols)
        ax = axes[row, col]
        # Normalize a possibly negative axis so it can be compared below.
        axis = project_axis % field.ndim
        slicing = [slice(None)] * field.ndim
        slicing[axis] = slice(None, sum_over)
        slicing = tuple(slicing)
        # Image rows come from the lower remaining axis and columns from the
        # higher one; imshow's extent is (left, right, bottom, top).
        row_axis, col_axis = [a for a in range(field.ndim) if a != axis]
        # NOTE(review): the +1 offset presumably maps a density contrast
        # delta to 1 + delta; confirm.
        image = ax.imshow(
            field[slicing].sum(axis=axis) + 1,
            cmap='magma',
            extent=[0, field.shape[col_axis], 0, field.shape[row_axis]],
            vmin=vmin,
            vmax=vmax)
        ax.set_xlabel('Mpc/h')
        ax.set_ylabel('Mpc/h')
        ax.set_title(f"{name} projection {axis}")
        if colorbar:
            fig.colorbar(image, ax=ax, shrink=0.7)
    # Remove the grid cells that received no field
    for j in range(nb_fields, nb_rows * nb_cols):
        fig.delaxes(axes.flat[j])
    plt.tight_layout()
    plt.show()
def stack_slices(array):
    """
    Reassembles a sharded array into a single NumPy array from its shards.

    Assumes a 2D device mesh of shape ``(pdims[0], pdims[1])``, where mesh
    axis 0 splits array axis 0 and mesh axis 1 splits array axis 1. It also
    assumes that ``array.addressable_data(k)`` returns the shard at mesh
    position ``(i, j)`` for ``k = i * pdims[1] + j`` (row-major order over
    the mesh).

    Args:
    - array: a sharded array exposing ``sharding.mesh.devices.shape`` and
      ``addressable_data(k)`` (e.g. a ``jax.Array`` with a ``NamedSharding``)

    Returns:
    - A single array built in two steps. Each mesh row's shards are
      concatenated along axis 1, then the resulting rows are stacked along
      axis 0.
    """
    pdims = array.sharding.mesh.devices.shape
    nb_rows, nb_cols = pdims[0], pdims[1]
    field_slices = []
    for i in range(nb_rows):
        # Row-major flat index into the mesh: the stride between rows is the
        # number of columns, not the number of rows.
        row_slices = [
            array.addressable_data(i * nb_cols + j) for j in range(nb_cols)
        ]
        # Join the shards of this mesh row along array axis 1
        field_slices.append(np.hstack(row_slices))
    # Stack the mesh rows along array axis 0 to form the full array
    return np.vstack(field_slices)