mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-04-07 20:30:54 +00:00
Put plotting in package
This commit is contained in:
parent
0946842fe5
commit
435c7c848f
6 changed files with 159 additions and 37 deletions
120
jaxpm/plotting.py
Normal file
120
jaxpm/plotting.py
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def plot_fields(fields_dict, sum_over=None):
|
||||||
|
"""
|
||||||
|
Plots sum projections of 3D fields along different axes,
|
||||||
|
slicing only the first `sum_over` elements along each axis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- fields: list of 3D arrays representing fields to plot
|
||||||
|
- names: list of names for each field, used in titles
|
||||||
|
- sum_over: number of slices to sum along each axis (default: fields[0].shape[0] // 8)
|
||||||
|
"""
|
||||||
|
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||||
|
nb_rows = len(fields_dict)
|
||||||
|
nb_cols = 3
|
||||||
|
fig, axes = plt.subplots(nb_rows, nb_cols, figsize=(15, 5 * nb_rows))
|
||||||
|
|
||||||
|
def plot_subplots(proj_axis, field, row, title):
|
||||||
|
slicing = [slice(None)] * field.ndim
|
||||||
|
slicing[proj_axis] = slice(None, sum_over)
|
||||||
|
slicing = tuple(slicing)
|
||||||
|
|
||||||
|
# Sum projection over the specified axis and plot
|
||||||
|
axes[row, proj_axis].imshow(
|
||||||
|
field[slicing].sum(axis=proj_axis) + 1,
|
||||||
|
cmap='magma',
|
||||||
|
extent=[0, field.shape[proj_axis], 0, field.shape[proj_axis]])
|
||||||
|
axes[row, proj_axis].set_xlabel('Mpc/h')
|
||||||
|
axes[row, proj_axis].set_ylabel('Mpc/h')
|
||||||
|
axes[row, proj_axis].set_title(title)
|
||||||
|
|
||||||
|
# Plot each field across the three axes
|
||||||
|
for i, (name, field) in enumerate(fields_dict.items()):
|
||||||
|
for proj_axis in range(3):
|
||||||
|
plot_subplots(proj_axis, field, i,
|
||||||
|
f"{name} projection {proj_axis}")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_fields_single_projection(fields_dict, sum_over=None, project_axis=0):
|
||||||
|
"""
|
||||||
|
Plots a single projection (along axis 0) of 3D fields in a grid,
|
||||||
|
summing over the first `sum_over` elements along the 0-axis, with 4 images per row.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- fields_dict: dictionary where keys are field names and values are 3D arrays
|
||||||
|
- sum_over: number of slices to sum along the projection axis (default: fields[0].shape[0] // 8)
|
||||||
|
"""
|
||||||
|
sum_over = sum_over or list(fields_dict.values())[0].shape[0] // 8
|
||||||
|
nb_fields = len(fields_dict)
|
||||||
|
nb_cols = 4 # Set number of images per row
|
||||||
|
nb_rows = (nb_fields + nb_cols - 1) // nb_cols # Calculate required rows
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(nb_rows,
|
||||||
|
nb_cols,
|
||||||
|
figsize=(5 * nb_cols, 5 * nb_rows))
|
||||||
|
axes = np.atleast_2d(axes) # Ensure axes is always a 2D array
|
||||||
|
|
||||||
|
for i, (name, field) in enumerate(fields_dict.items()):
|
||||||
|
row, col = divmod(i, nb_cols)
|
||||||
|
|
||||||
|
# Define the slice for the 0-axis projection
|
||||||
|
slicing = [slice(None)] * field.ndim
|
||||||
|
slicing[project_axis] = slice(None, sum_over)
|
||||||
|
slicing = tuple(slicing)
|
||||||
|
|
||||||
|
# Sum projection over axis 0 and plot
|
||||||
|
axes[row, col].imshow(field[slicing].sum(axis=project_axis) + 1,
|
||||||
|
cmap='magma',
|
||||||
|
extent=[0, field.shape[1], 0, field.shape[2]])
|
||||||
|
axes[row, col].set_xlabel('Mpc/h')
|
||||||
|
axes[row, col].set_ylabel('Mpc/h')
|
||||||
|
axes[row, col].set_title(f"{name} projection 0")
|
||||||
|
|
||||||
|
# Remove any empty subplots
|
||||||
|
for j in range(i + 1, nb_rows * nb_cols):
|
||||||
|
fig.delaxes(axes.flatten()[j])
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def stack_slices(array):
|
||||||
|
"""
|
||||||
|
Stacks 2D slices of an array into a single array based on provided partition dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- array_slices: a 2D list of array slices (list of lists format) where
|
||||||
|
array_slices[i][j] is the slice located at row i, column j in the grid.
|
||||||
|
- pdims: a tuple representing the grid dimensions (rows, columns).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- A single array constructed by stacking the slices.
|
||||||
|
"""
|
||||||
|
# Initialize an empty list to store the vertically stacked rows
|
||||||
|
pdims = array.sharding.mesh.devices.shape
|
||||||
|
|
||||||
|
field_slices = []
|
||||||
|
|
||||||
|
# Iterate over rows in pdims[0]
|
||||||
|
for i in range(pdims[0]):
|
||||||
|
row_slices = []
|
||||||
|
|
||||||
|
# Iterate over columns in pdims[1]
|
||||||
|
for j in range(pdims[1]):
|
||||||
|
slice_index = i * pdims[0] + j
|
||||||
|
row_slices.append(array.addressable_data(slice_index))
|
||||||
|
# Stack the current row of slices vertically
|
||||||
|
stacked_row = np.hstack(row_slices)
|
||||||
|
field_slices.append(stacked_row)
|
||||||
|
|
||||||
|
# Stack all rows horizontally to form the full array
|
||||||
|
full_array = np.vstack(field_slices)
|
||||||
|
|
||||||
|
return full_array
|
|
@ -63,17 +63,18 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": null,
|
||||||
"id": "281b4d3b",
|
"id": "281b4d3b",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "281b4d3b"
|
"id": "281b4d3b"
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"mesh_shape = [256, 256, 256]\n",
|
"mesh_shape = [128, 128, 128]\n",
|
||||||
"box_size = [256., 256., 256.]\n",
|
"box_size = [128., 128., 128.]\n",
|
||||||
"snapshots = jnp.array([0.1, 0.5, 1.0])\n",
|
"snapshots = jnp.array([0.1, 0.5, 1.0])\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"@jax.jit\n",
|
||||||
"def run_simulation(omega_c, sigma8):\n",
|
"def run_simulation(omega_c, sigma8):\n",
|
||||||
" # Create a small function to generate the matter power spectrum\n",
|
" # Create a small function to generate the matter power spectrum\n",
|
||||||
" k = jnp.logspace(-4, 1, 128)\n",
|
" k = jnp.logspace(-4, 1, 128)\n",
|
||||||
|
@ -126,7 +127,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": null,
|
||||||
"id": "4e012ce8",
|
"id": "4e012ce8",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
|
@ -149,7 +150,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
||||||
"for i , field in enumerate(ode_particles[1:]):\n",
|
"for i , field in enumerate(ode_particles[1:]):\n",
|
||||||
|
@ -169,7 +170,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": null,
|
||||||
"id": "b71824ed",
|
"id": "b71824ed",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -182,8 +183,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"mesh_shape = [256, 256, 256]\n",
|
"mesh_shape = [128, 128, 128]\n",
|
||||||
"box_size = [256., 256., 256.]\n",
|
"box_size = [128., 128., 128.]\n",
|
||||||
"snapshots = jnp.array([0.1, 0.5, 1.0])\n",
|
"snapshots = jnp.array([0.1, 0.5, 1.0])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@jax.jit\n",
|
"@jax.jit\n",
|
||||||
|
@ -216,7 +217,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": null,
|
||||||
"id": "33b5e684",
|
"id": "33b5e684",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -232,7 +233,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
||||||
"for i , field in enumerate(ode_displacements[1:]):\n",
|
"for i , field in enumerate(ode_displacements[1:]):\n",
|
||||||
" fields[f\"field_{i}\"] = cic_paint_dx(field)\n",
|
" fields[f\"field_{i}\"] = cic_paint_dx(field)\n",
|
||||||
|
|
|
@ -50,7 +50,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -63,8 +63,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"mesh_shape = [256, 256, 256]\n",
|
"mesh_shape = [128, 128, 128]\n",
|
||||||
"box_size = [256., 256., 256.]\n",
|
"box_size = [128., 128., 128.]\n",
|
||||||
"snapshots = jnp.array([0.5, 1.0])\n",
|
"snapshots = jnp.array([0.5, 1.0])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@jax.jit\n",
|
"@jax.jit\n",
|
||||||
|
@ -111,7 +111,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -126,7 +126,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
||||||
"for i , field in enumerate(ode_solutions):\n",
|
"for i , field in enumerate(ode_solutions):\n",
|
||||||
|
@ -145,7 +145,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -160,8 +160,8 @@
|
||||||
"source": [
|
"source": [
|
||||||
"from functools import partial \n",
|
"from functools import partial \n",
|
||||||
"\n",
|
"\n",
|
||||||
"mesh_shape = [256, 256, 256]\n",
|
"mesh_shape = [128, 128, 128]\n",
|
||||||
"box_size = [256., 256., 256.]\n",
|
"box_size = [128., 128., 128.]\n",
|
||||||
"snapshots = jnp.array([0.5,1.])\n",
|
"snapshots = jnp.array([0.5,1.])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@partial(jax.jit , static_argnums=(2,))\n",
|
"@partial(jax.jit , static_argnums=(2,))\n",
|
||||||
|
@ -224,7 +224,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -236,8 +236,8 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"mesh_shape = [256, 256, 256]\n",
|
"mesh_shape = [128, 128, 128]\n",
|
||||||
"box_size = [256., 256., 256.]\n",
|
"box_size = [128., 128., 128.]\n",
|
||||||
"snapshots = jnp.array([0.1 ,0.5, 1.])\n",
|
"snapshots = jnp.array([0.1 ,0.5, 1.])\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@jax.jit\n",
|
"@jax.jit\n",
|
||||||
|
@ -283,7 +283,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -298,7 +298,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint(jnp.zeros(mesh_shape) ,lpt_particles)}\n",
|
||||||
"for i , field in enumerate(ode_particles[1:]):\n",
|
"for i , field in enumerate(ode_particles[1:]):\n",
|
||||||
|
@ -321,7 +321,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -337,7 +337,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from math import prod\n",
|
"from math import prod\n",
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
||||||
"center3d = (slice(None) , center,center) \n",
|
"center3d = (slice(None) , center,center) \n",
|
||||||
|
@ -364,7 +364,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -380,7 +380,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from math import prod\n",
|
"from math import prod\n",
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
"center = slice(mesh_shape[0] // 4, 3 * mesh_shape[0] // 4 )\n",
|
||||||
"center3d = (slice(None) , center,center) \n",
|
"center3d = (slice(None) , center,center) \n",
|
||||||
|
|
|
@ -283,7 +283,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"id": "4e012ce8",
|
"id": "4e012ce8",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
|
@ -306,7 +306,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
||||||
"for i , field in enumerate(ode_solutions):\n",
|
"for i , field in enumerate(ode_solutions):\n",
|
||||||
|
@ -342,7 +342,7 @@
|
||||||
"lpt_displacements_g = all_gather(lpt_displacements)\n",
|
"lpt_displacements_g = all_gather(lpt_displacements)\n",
|
||||||
"ode_solutions_g = [all_gather(p) for p in ode_solutions]\n",
|
"ode_solutions_g = [all_gather(p) for p in ode_solutions]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions , \"LPT Field\" : cic_paint_dx(lpt_displacements)}\n",
|
||||||
"for i , field in enumerate(ode_solutions):\n",
|
"for i , field in enumerate(ode_solutions):\n",
|
||||||
|
@ -447,7 +447,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 78,
|
"execution_count": null,
|
||||||
"id": "0acb5253",
|
"id": "0acb5253",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -463,7 +463,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fields = {\"Initial Conditions\" : initial_conditions_g , \"LPT Field\" : lpt_field_g}\n",
|
"fields = {\"Initial Conditions\" : initial_conditions_g , \"LPT Field\" : lpt_field_g}\n",
|
||||||
"for i , field in enumerate(ode_fields_g):\n",
|
"for i , field in enumerate(ode_fields_g):\n",
|
||||||
|
@ -600,7 +600,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": null,
|
||||||
"id": "59cfba84",
|
"id": "59cfba84",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
@ -652,7 +652,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from math import prod\n",
|
"from math import prod\n",
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"field = ode_solutions[0]\n",
|
"field = ode_solutions[0]\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
|
@ -159,7 +159,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -181,7 +181,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"\n",
|
"\n",
|
||||||
"mesh_shape = 1024\n",
|
"mesh_shape = 1024\n",
|
||||||
"box_size = 1000.\n",
|
"box_size = 1000.\n",
|
||||||
|
|
|
@ -259,7 +259,7 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from visualize import plot_fields_single_projection\n",
|
"from jaxpm.plotting import plot_fields_single_projection\n",
|
||||||
"fields = {\n",
|
"fields = {\n",
|
||||||
" \"Initial Conditions\": initial_conditions,\n",
|
" \"Initial Conditions\": initial_conditions,\n",
|
||||||
" \"LPT Field\": lpt_displacements,\n",
|
" \"LPT Field\": lpt_displacements,\n",
|
||||||
|
|
Loading…
Add table
Reference in a new issue