JaxPM/scripts/eval_decomp_ode.ipynb
Wassim KABALAN 831291c1f9 bench
2024-08-02 23:39:09 +02:00

254 KiB

In [4]:
import numpy as np
import matplotlib.pyplot as plt
In [32]:
# Decomposition grid used by the run: pdims[0] x pdims[1] tiles are loaded below
# and appear in the output folder name. (1, 1) is the single-GPU case.
pdims = (1, 1)
In [35]:
# Folder holding the per-tile .npy files written by the benchmark run.
# Other runs that can be compared (swap in one of these for `folder`):
#   f'out/final_field/pmwd/1/128_128/{pdims[0]}x{pdims[1]}/lfm/halo_0'
#   f'out/final_field/jaxpm/1/128_128/{pdims[0]}x{pdims[1]}/LeapfrogMidpoint/halo_32'
folder = f'out/final_field/jaxpm/1/128_128/{pdims[0]}x{pdims[1]}/lpt/halo_32'

# When True, only the final_field_* files are loaded (no field_* / initial_conditions_*).
only_final_fields = True

# Number of saved solutions, and how many of the last ones to load and plot.
nb_solutions = 1
nb_to_plot = 1
assert 1 <= nb_to_plot <= nb_solutions, 'nb_to_plot must be between 1 and nb_solutions'


def load_stitched(folder, name, pdims):
    """Load the tiles ``{folder}/{name}_{k}.npy`` and stitch them into one array.

    Tile index ``k = i * pdims[1] + j`` for ``i < pdims[0]``, ``j < pdims[1]``.
    Tiles sharing the same ``i`` are joined with ``np.vstack`` (axis 0) and the
    resulting rows with ``np.hstack`` (axis 1 for arrays of rank >= 2).
    NOTE(review): assumes this matches the tile layout the benchmark writes — confirm.

    Args:
        folder: Directory containing the tile files.
        name: File-name prefix, e.g. ``'field'`` or ``'final_field_0'``.
        pdims: ``(rows, cols)`` of the tile grid.

    Returns:
        The stitched ``np.ndarray``.
    """
    rows = []
    for i in range(pdims[0]):
        tiles = [np.load(f'{folder}/{name}_{i * pdims[1] + j}.npy') for j in range(pdims[1])]
        rows.append(np.vstack(tiles))
    return np.hstack(rows)


if not only_final_fields:
    field = load_stitched(folder, 'field', pdims)
    initial_conditions = load_stitched(folder, 'initial_conditions', pdims)

# Only the last `nb_to_plot` solutions are loaded, in increasing solution order.
final_fields = [
    load_stitched(folder, f'final_field_{sol_indx}', pdims)
    for sol_indx in range(nb_solutions - nb_to_plot, nb_solutions)
]

reference_field = final_fields[-1] if only_final_fields else field
box_size = reference_field.shape
print(box_size)
(128, 128, 128)
In [36]:
# Depth (in cells) of the slab summed along the projection axis: the first
# eighth of the box.
sum_over = box_size[0] // 8


def plot_subplots(proj_axis, input, row, axes, title):
    """Draw a projected slab of ``input`` into ``axes[row, proj_axis]``.

    The first ``sum_over`` cells of ``input`` along ``proj_axis`` are summed and
    the resulting 2-D map is shown with the 'magma' colormap.

    Args:
        proj_axis: Axis of ``input`` to project along; also selects the subplot column.
        input: Field to project (presumably 3-D with shape ``box_size`` — TODO confirm).
        row: Subplot row to draw into.
        axes: 2-D array of matplotlib Axes, indexed as ``axes[row, col]``.
        title: Title for the subplot.
    """
    slicing = [slice(None)] * input.ndim
    slicing[proj_axis] = slice(None, sum_over)
    projection = input[tuple(slicing)].sum(axis=proj_axis)

    # box_size is a shape tuple, so take its first entry for the extent; this
    # matches the extent used for the final-field panels.
    # NOTE(review): the +5 padding is kept as-is; its purpose is not evident — confirm.
    side = box_size[0] + 5
    ax = axes[row, proj_axis]
    ax.imshow(projection, cmap='magma', extent=[0, side, 0, side])
    # NOTE(review): 'Mpc/h' labels assume one grid cell per Mpc/h — confirm.
    ax.set_xlabel('Mpc/h')
    ax.set_ylabel('Mpc/h')
    ax.set_title(title)

# Figure layout: optionally two rows (initial conditions, LPT field), then one
# row per loaded final field; one column per projection axis.
if only_final_fields:
    nb_rows = len(final_fields)
    field_start = 0
else:
    nb_rows = 2 + len(final_fields)
    field_start = 2

nb_cols = 3
# squeeze=False makes `axes` a 2-D (nb_rows, nb_cols) array for any nb_rows,
# so axes[row, col] indexing works uniformly below.
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows), squeeze=False)

# Plot initial conditions and LPT field for each projection
if not only_final_fields:
    for proj_axis in range(3):
        plot_subplots(proj_axis, initial_conditions, 0, axes, f'Initial conditions projection {proj_axis}')
        plot_subplots(proj_axis, field, 1, axes, f'LPT density field projection {proj_axis}')

# Index of the first loaded solution; final_fields[k] holds solution first_step + k.
first_step = nb_solutions - nb_to_plot

# Plot final fields for each projection
for indx, final_field in enumerate(final_fields):
    row = indx + field_start
    for proj_axis in range(3):
        slicing = [slice(None)] * final_field.ndim
        slicing[proj_axis] = slice(None, sum_over)
        # NOTE(review): the +1 offset is kept from the original; presumably it
        # shifts a density contrast to positive values — confirm.
        projection = final_field[tuple(slicing)].sum(axis=proj_axis) + 1
        ax = axes[row, proj_axis]
        ax.imshow(projection, cmap='magma', extent=[0, box_size[0] + 5, 0, box_size[0] + 5])
        ax.set_xlabel('Mpc/h')
        ax.set_ylabel('Mpc/h')
        ax.set_title(f'ODE Step {first_step + indx} projection {proj_axis}')

plt.tight_layout()
plt.show()
No description has been provided for this image
In [ ]: