mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-02-22 09:37:11 +00:00
* adding example of distributed solution * put back old functgion * update formatting * add halo exchange and slice pad * apply formatting * implement distributed optimized cic_paint * Use new cic_paint with halo * Fix seed for distributed normal * Wrap interpolation function to avoid all gather * Return normal order frequencies for single GPU * add example * format * add optimised bench script * times in ms * add lpt2 * update benchmark and add slurm * Visualize only final field * Update scripts/distributed_pm.py Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> * Adjust pencil type for frequencies * fix painting issue with slabs * Shared operation in fourrier space now take inverted sharding axis for slabs * add assert to make pyright happy * adjust test for hpc-plotter * add PMWD test * bench * format * added github workflow * fix formatting from main * Update for jaxDecomp pure JAX * revert single halo extent change * update for latest jaxDecomp * remove fourrier_space in autoshmap * make normal_field work with single controller * format * make distributed pm work in single controller * merge bench_pm * update to leapfrog * add a strict dependency on jaxdecomp * global mesh no longer needed * kernels.py no longer uses global mesh * quick fix in distributed * pm.py no longer uses global mesh * painting.py no longer uses global mesh * update demo script * quick fix in kernels * quick fix in distributed * update demo * merge hugos LPT2 code * format * Small fix * format * remove duplicate get_ode_fn * update visualizer * update compensate CIC * By default check_rep is false for shard_map * remove experimental distributed code * update PGDCorrection and neural ode to use new fft3d * jaxDecomp pfft3d promotes to complex automatically * remove deprecated stuff * fix painting issue with read_cic * use jnp interp instead of jc interp * delete old slurms * add notebook examples * apply formatting * add distributed zeros * fix code in LPT2 * jit cic_paint 
* update notebooks * apply formating * get local shape and zeros can be used by users * add a user facing function to create uniform particle grid * use jax interp instead of jax_cosmo * use float64 for enmeshing * Allow applying weights with relative cic paint * Weights can be traced * remove script folder * update example notebooks * delete outdated design file * add readme for tutorials * update readme * fix small error * forgot particles in multi host * clarifying why cic_paint_dx is slower * clarifying the halo size dependence on the box size * ability to choose snapshots number with MultiHost script * Adding animation notebook * Put plotting in package * Add finite difference laplace kernel + powerspec functions from Hugo Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> * Put plotting utils in package * By default use absoulute painting with * update code * update notebooks * add tests * Upgrade setup.py to pyproject * Format * format tests * update test dependencies * add test workflow * fix deprecated FftType in jaxpm.kernels * Add aboucaud comments * JAX version is 0.4.35 until Diffrax new release * add numpy explicitly as dependency for tests * fix install order for tests * add numpy to be installed * enforce no build isolation for fastpm * pip install jaxpm test without build isolation * bump jaxdecomp version * revert test workflow * remove outdated tests --------- Co-authored-by: EiffL <fr.eiffel@gmail.com> Co-authored-by: Francois Lanusse <EiffL@users.noreply.github.com> Co-authored-by: Wassim KABALAN <wassim@apc.in2p3.fr> Co-authored-by: Hugo Simonfroy <hugo.simonfroy@gmail.com> Former-commit-id: 8c2e823d4669eac712089bf7f85ffb7912e8232d
129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
|
|
def plot_fields(fields_dict, sum_over=None):
    """
    Plots sum projections of 3D fields along each of the three axes,
    summing only the first `sum_over` elements along the projected axis.

    Each field gets one row of the figure; column `k` shows the projection
    along axis `k`. One is added to every projected map before display.

    Args:
    - fields_dict: dictionary mapping a field name (used in the subplot
      titles) to a 3D array
    - sum_over: number of slices to sum along the projected axis
      (default: shape[0] // 8 of the first field in `fields_dict`)
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8

    nb_rows = len(fields_dict)
    nb_cols = 3

    _, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
    # With a single field plt.subplots returns a 1D array of axes; promote it
    # so the axes[row, col] indexing below works for any number of fields.
    axes = np.atleast_2d(axes)

    def plot_subplots(proj_axis, field, row, title):
        # Keep only the first `sum_over` slices along the projected axis.
        slicing = [slice(None)] * field.ndim
        slicing[proj_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # The projected image keeps the two other axes, in order; imshow draws
        # the first one vertically and the second one horizontally, so the
        # extent must come from those axes (not the summed-out one) to be
        # correct for non-cubic fields.
        remaining = np.delete(field.shape, proj_axis)

        ax = axes[row, proj_axis]
        ax.imshow(field[slicing].sum(axis=proj_axis) + 1,
                  cmap='magma',
                  extent=[0, remaining[1], 0, remaining[0]])
        ax.set_xlabel('Mpc/h')
        ax.set_ylabel('Mpc/h')
        ax.set_title(title)

    # Plot each field across the three axes
    for row, (name, field) in enumerate(fields_dict.items()):
        for proj_axis in range(3):
            plot_subplots(proj_axis, field, row,
                          f"{name} projection {proj_axis}")

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def plot_fields_single_projection(fields_dict,
                                  sum_over=None,
                                  project_axis=0,
                                  vmin=None,
                                  vmax=None,
                                  colorbar=False):
    """
    Plots one sum projection of each 3D field in a grid with 4 images per row.

    For every field, the first `sum_over` slices along `project_axis` are
    summed, one is added, and the resulting 2D map is displayed. Unused
    subplots in the last row are removed.

    Args:
    - fields_dict: dictionary where keys are field names (used in titles) and
      values are 3D arrays
    - sum_over: number of slices to sum along the projection axis
      (default: shape[0] // 8 of the first field in `fields_dict`)
    - project_axis: axis along which the fields are projected (default: 0)
    - vmin, vmax: color limits forwarded to `imshow` (default: None)
    - colorbar: if True, draw a colorbar next to each image
    """
    sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8

    nb_fields = len(fields_dict)
    nb_cols = 4  # Number of images per row
    nb_rows = (nb_fields + nb_cols - 1) // nb_cols  # Ceiling division

    fig, axes = plt.subplots(nb_rows,
                             nb_cols,
                             figsize=(5 * nb_cols, 5 * nb_rows))
    axes = np.atleast_2d(axes)  # Ensure axes is always a 2D array

    for i, (name, field) in enumerate(fields_dict.items()):
        ax = axes[divmod(i, nb_cols)]

        # Keep only the first `sum_over` slices along the projection axis
        slicing = [slice(None)] * field.ndim
        slicing[project_axis] = slice(None, sum_over)
        slicing = tuple(slicing)

        # The projected image keeps the two other axes, in order; imshow draws
        # the first one vertically and the second one horizontally.
        # np.delete also accepts a negative `project_axis`.
        remaining = np.delete(field.shape, project_axis)

        image = ax.imshow(field[slicing].sum(axis=project_axis) + 1,
                          cmap='magma',
                          extent=[0, remaining[1], 0, remaining[0]],
                          vmin=vmin,
                          vmax=vmax)
        ax.set_xlabel('Mpc/h')
        ax.set_ylabel('Mpc/h')
        ax.set_title(f"{name} projection {project_axis}")
        if colorbar:
            fig.colorbar(image, ax=ax, shrink=0.7)

    # Remove the empty subplots that follow the last field
    for ax in axes.flat[nb_fields:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()
|
|
|
|
|
|
def stack_slices(array):
    """
    Stacks the addressable shards of a 2D-sharded array into one array.

    The device grid shape `pdims = (rows, columns)` is read from the array's
    sharding mesh. Shard `i * columns + j` is placed at grid position (i, j):
    shards of one grid row are concatenated along axis 1 (`np.hstack`), and
    the resulting rows are concatenated along axis 0 (`np.vstack`).

    Args:
    - array: object exposing `sharding.mesh.devices.shape` and
      `addressable_data(index)` -- presumably a `jax.Array` with a sharding
      over a 2D device mesh (TODO confirm).

    Returns:
    - A single NumPy array constructed by stacking the shards.
    """
    pdims = array.sharding.mesh.devices.shape
    nb_rows, nb_cols = pdims[0], pdims[1]

    # NOTE(review): assumes addressable_data(k) returns the shard at row-major
    # grid position k; confirm the shard ordering for multi-host setups.
    field_rows = []
    for i in range(nb_rows):
        # Row-major linear index of grid cell (i, j) is i * nb_cols + j.
        row_slices = [
            array.addressable_data(i * nb_cols + j) for j in range(nb_cols)
        ]
        # Concatenate the shards of this grid row side by side (axis 1)
        field_rows.append(np.hstack(row_slices))

    # Concatenate the grid rows on top of each other (axis 0)
    return np.vstack(field_rows)
|