Add Spherical lensing example

This commit is contained in:
Wassim Kabalan 2025-06-28 19:25:14 +02:00
parent 2d21985279
commit f6d547e31f
5 changed files with 1048 additions and 381 deletions

View file

@ -1,22 +1,32 @@
import jax
import jax.numpy as jnp
import jax_cosmo
import jax_cosmo as jc
import jax_cosmo.constants as constants
from jax.scipy.ndimage import map_coordinates
from jaxpm.painting import cic_paint_2d
from jaxpm.distributed import uniform_particles
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
from jaxpm.spherical import paint_spherical
from jaxpm.utils import gaussian_smoothing
def density_plane(positions,
box_shape,
center,
width,
plane_resolution,
smoothing_sigma=None):
""" Extacts a density plane from the simulation
"""
def density_plane_fn(box_shape,
box_size,
density_plane_width,
density_plane_npix,
sharding=None):
def f(t, y, args):
positions = y[0]
cosmo = args
nx, ny, nz = box_shape
# Converts time t to comoving distance in voxel coordinates
w = density_plane_width / box_size[2] * box_shape[2]
center = jc.background.radial_comoving_distance(
cosmo, t) / box_size[2] * box_shape[2]
positions = uniform_particles(box_shape) + positions
xy = positions[..., :2]
d = positions[..., 2]
@ -24,59 +34,148 @@ def density_plane(positions,
xy = jnp.mod(xy, nx)
# Rescaling positions to target grid
xy = xy / nx * plane_resolution
xy = xy / nx * density_plane_npix
# Selecting only particles that fall inside the volume of interest
weight = jnp.where(
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
weight = jnp.where((d > (center - w / 2)) & (d <= (center + w / 2)),
1.0, 0.0)
# Painting density plane
density_plane = cic_paint_2d(
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
zero_mesh = jnp.zeros([density_plane_npix, density_plane_npix])
# Apply sharding in order to recover sharding when taking gradients
if sharding is not None:
xy = jax.lax.with_sharding_constraint(xy, sharding)
# Apply CIC painting
density_plane = cic_paint_2d(zero_mesh, xy, weight)
# Apply density normalization
density_plane = density_plane / ((nx / plane_resolution) *
(ny / plane_resolution) * (width))
# Apply Gaussian smoothing if requested
if smoothing_sigma is not None:
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
density_plane = density_plane / ((nx / density_plane_npix) *
(ny / density_plane_npix) * w)
return density_plane
return f
def convergence_Born(cosmo, density_planes, coords, z_source):
def spherical_density_fn(box_shape,
box_size,
nside,
fov,
center_radec,
observer_position,
d_R,
sharding=None):
def f(t, y, args):
positions = y[0]
nx, ny, nz = box_shape
bx, by, bz = box_size
cosmo = args
# Converts time t to comoving distance in voxel coordinates
w = d_R / box_size[2] * box_shape[2]
center = ((jc.background.radial_comoving_distance(cosmo, t)) / bz) * nz
# Apply sharding in order to recover sharding when taking gradients
if sharding is not None:
positions = jax.lax.with_sharding_constraint(positions, sharding)
density_mesh = cic_paint_dx(positions)
# Project to spherical map
spherical_map = paint_spherical(density_mesh, nside, fov, center_radec,
observer_position, box_size, center,
d_R)
return spherical_map
return f
# ==========================================================
# Weak Lensing Born Approximation
# ==========================================================
def convergence_Born(cosmo, density_planes, r, a, dx, dz, coords, z_source):
"""
Compute the Born convergence
Args:
cosmo: `Cosmology`, cosmology object.
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
name: `string`, name of the operation.
Returns:
`Tensor` of shape [batch_size, N, Nz], of convergence values.
Compute Born-approximation lensing convergence maps.
Parameters
----------
cosmo : jc.Cosmology
Cosmology object.
density_planes : ndarray
3D array of lensing density planes [nx, ny, n_planes].
r, a : ndarray
Comoving distances and scale factors per plane.
dx : float
Pixel scale.
dz : float
Redshift bin width.
coords : ndarray
Angular coordinates grid [2, N, 2] in radians.
z_source : ndarray
Source redshifts.
Returns
-------
convergence : ndarray
2D convergence map for each source redshift.
"""
# Compute constant prefactor:
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jax_cosmo.background.radial_comoving_distance(
cosmo, 1 / (1 + z_source))
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
n_planes = len(r)
convergence = 0
for entry in density_planes:
r = entry['r']
a = entry['a']
p = entry['plane']
dx = entry['dx']
dz = entry['dz']
# Normalize density planes
density_normalization = dz * r / a
def scan_fn(carry, i):
density_planes, a, r = carry
p = density_planes[:, :, i]
density_normalization = dz * r[i] / a[i]
p = (p - p.mean()) * constant_factor * density_normalization
# Interpolate at the density plane coordinates
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
im = map_coordinates(p, coords * r[i] / dx - 0.5, order=1, mode="wrap")
convergence += im * jnp.clip(1. -
(r / r_s), 0, 1000).reshape([-1, 1, 1])
return carry, im * jnp.clip(1.0 -
(r[i] / r_s), 0, 1000).reshape([-1, 1, 1])
return convergence
_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r),
jnp.arange(n_planes))
return convergence.sum(axis=0)
def spherical_convergence_Born(cosmo, density_planes, r, a, nside, z_source):
"""
Compute Born-approximation lensing convergence maps on a sphere.
Parameters
----------
cosmo : jc.Cosmology
Cosmology object.
density_planes : ndarray
3D array of lensing density planes [n_planes, npix].
r, a : ndarray
Comoving distances and scale factors per plane.
nside : int
Healpix nside parameter.
z_source : ndarray
Source redshifts.
Returns
-------
convergence : ndarray
2D convergence map for each source redshift.
"""
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
# Compute comoving distance of source galaxies
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
n_planes = len(r)
def scan_fn(carry, i):
density_planes, a, r = carry
p = density_planes[i, :]
density_normalization = r[i] / a[
i] # This normalization needs to be checked
p = (p - p.mean()) * constant_factor * density_normalization
return carry, p * jnp.clip(1.0 -
(r[i] / r_s), 0, 1000).reshape([-1, 1])
_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r),
jnp.arange(n_planes))
return convergence.sum(axis=0)

133
jaxpm/ode.py Normal file
View file

@ -0,0 +1,133 @@
from jaxpm.growth import E, Gf, dGfa, gp
from jaxpm.growth import growth_factor as Gp
from jaxpm.pm import pm_forces
def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None):
def drift(a, vel, args):
"""
state is a tuple (position, velocities)
"""
cosmo = args[0]
# Get the time steps
t0 = a
t1 = a + dt0
# Set the scale factors
ai = t0
ac = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
af = t1
#drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac)
drift_contr = (af - ai )/ ac
# Computes the update of position (drift)
dpos = 1 / (ac**3 * E(cosmo, ac)) * vel
return dpos * (drift_contr / dt0)
def kick(a, pos, args):
"""
state is a tuple (position, velocities)
"""
# Computes the update of velocity (kick)
cosmo = args
# Get the time steps
t0 = a
t1 = t0 + dt0
t2 = t1 + dt0
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
t1t2 = (t1 * t2) ** 0.5 # Geometric mean of t1 and t2
# Set the scale factors
ac = t1
forces = (
pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
)
* 1.5
* cosmo.Omega_m
)
# Computes the update of velocity (kick)
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
# First kick control factor
kick_factor_1 = (t1 - t0t1) / t1
#kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1)
# Second kick control factor
kick_factor_2 = (t2 - t1t2) / t2
#kick_factor_2 = (Gf(cosmo, t1t2) - Gf(cosmo, t1)) / dGfa(cosmo, t1)
return dvel * ((kick_factor_1 + kick_factor_2) / dt0)
def first_kick(a, pos, args):
"""
state is a tuple (position, velocities)
"""
# Computes the update of velocity (kick)
cosmo = args
# Get the time steps
t0 = a
t1 = t0 + dt0
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
forces = (
pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
)
* 1.5
* cosmo.Omega_m
)
# Computes the update of velocity (kick)
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
# First kick control factor
kick_factor = (Gf(cosmo, t0t1) - Gf(cosmo, t0)) / dGfa(cosmo, t0)
return dvel * (kick_factor / dt0)
return drift, kick, first_kick
def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None):
def drift(a, vel, args):
"""
state is a tuple (position, velocities)
"""
cosmo = args
# Computes the update of position (drift)
dpos = 1 / (a**3 * E(cosmo, a)) * vel
return dpos
def kick(a, pos, args):
"""
state is a tuple (position, velocities)
"""
# Computes the update of velocity (kick)
cosmo = args
forces = (
pm_forces(
pos,
mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size,
sharding=sharding,
)
* 1.5
* cosmo.Omega_m
)
# Computes the update of velocity (kick)
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
return dvel
return drift, kick

50
jaxpm/spherical.py Normal file
View file

@ -0,0 +1,50 @@
import jax.numpy as jnp
import jax_healpy as jhp
import matplotlib.pyplot as plt
import jax
from functools import partial
import healpy as hp
@partial(jax.jit, static_argnames=('nside', 'fov', 'center_radec' , 'd_R' , 'box_size'))
def paint_spherical(volume, nside, fov, center_radec, observer_position, box_size, R, d_R):
width, height, depth = volume.shape
ra0, dec0 = center_radec
fov_width, fov_height = fov
pixel_scale_x = fov_width / width
pixel_scale_y = fov_height / height
res_deg = jhp.nside2resol(nside, arcmin=True) / 60
if pixel_scale_x > res_deg or pixel_scale_y > res_deg:
print(f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside.")
y_idx, x_idx = jnp.indices((height, width))
ra_grid = ra0 + x_idx * pixel_scale_x
dec_grid = dec0 + y_idx * pixel_scale_y
ra_flat = ra_grid.flatten() * jnp.pi / 180.0
dec_flat = dec_grid.flatten() * jnp.pi / 180.0
R_s = jnp.arange(0 , d_R, 1.0) + R
XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False)
observer_position = jnp.array(observer_position)
# Convert observer position from box units to grid units
observer_position = observer_position / jnp.array(box_size) * jnp.array(volume.shape)
coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :]
pixels = jhp.ang2pix(nside, ra_flat, dec_flat, lonlat=False)
npix = jhp.nside2npix(nside)
@partial(jax.vmap, in_axes=(0, None, None))
def interpolate_volume(coords, volume, pixels):
voxels = jax.scipy.ndimage.map_coordinates(volume, coords.T, order=1)
sums = jnp.bincount(pixels, weights=voxels, length=npix)
return sums
sum_map = interpolate_volume(coords, volume, pixels).sum(axis=0)
counts = jnp.bincount(pixels, length=npix)
sum_map = jnp.where(counts > 0, sum_map / counts, jhp.UNSEEN)
return sum_map

View file

@ -1,320 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# **Animating Particle Mesh density fields**\n",
"\n",
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
"\n",
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
"\n",
"just like in notebook [**05-MultiHost_PM.ipynb**]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**A few hours later**\n",
"\n",
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
"\n",
"In this example, weve been allocated **32 GPUs split across 4 nodes**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!squeue -u $USER -o \"%i %D %b\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"del os.environ['VSCODE_PROXY_URI']\n",
"del os.environ['NO_PROXY']\n",
"del os.environ['no_proxy']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Checking Available Compute Resources\n",
"\n",
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-Host Simulation Script with Arguments (reminder)\n",
"\n",
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Heres a breakdown of the key arguments:\n",
"\n",
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
"\n",
"### Running the Multi-Host Simulation Script\n",
"\n",
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
"\n",
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
"\n",
"The command to run the multi-host simulation with these settings will look something like this:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"# Define parameters as variables\n",
"jobid = \"467745\"\n",
"num_processes = 32\n",
"script_name = \"05-MultiHost_PM.py\"\n",
"mesh_shape = (1024, 1024, 1024)\n",
"box_size = (1000., 1000., 1000.)\n",
"halo_size = 128\n",
"solver = \"leapfrog\"\n",
"pdims = (16, 2)\n",
"snapshots = 8\n",
"\n",
"# Build the command as a list, incorporating variables\n",
"command = [\n",
" \"srun\",\n",
" f\"--jobid={jobid}\",\n",
" \"-n\", str(num_processes),\n",
" \"python\", script_name,\n",
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
" \"--halo_size\", str(halo_size),\n",
" \"-s\", solver,\n",
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
" \"--snapshots\", str(snapshots)\n",
"]\n",
"\n",
"# Execute the command as a subprocess\n",
"subprocess.run(command)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Projecting the 3D Density Fields to 2D\n",
"\n",
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
"\n",
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
"\n",
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
"\n",
"### Applying the Magma Colormap\n",
"\n",
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
"\n",
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
" - First, normalize the array to the `[0, 1]` range.\n",
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import colormaps\n",
"\n",
"# Define a function to project the 3D field to 2D\n",
"def project_to_2d(field):\n",
" sum_over = field.shape[0] // 8\n",
" slicing = [slice(None)] * field.ndim\n",
" slicing[0] = slice(None, sum_over)\n",
" slicing = tuple(slicing)\n",
"\n",
" return field[slicing].sum(axis=0)\n",
"\n",
"\n",
"def apply_colormap(array, cmap_name=\"magma\"):\n",
" cmap = colormaps[cmap_name]\n",
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
" return (colored_image * 255).astype(np.uint8)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading and Visualizing Results\n",
"\n",
"After running the multi-host simulation, we load the saved results from disk:\n",
"\n",
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
"\n",
"We will now project the fields to 2D maps and apply the color map\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
"ode_solutions = []\n",
"for i in range(8):\n",
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Animating with Manim\n",
"\n",
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from manim import *\n",
"config.media_width = \"100%\"\n",
"config.verbosity = \"WARNING\"\n",
"config.background_color = \"#00000000\" # Transparent background"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining the Animation in Manim\n",
"\n",
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
"\n",
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
"- **Animation Sequence**:\n",
" - The animation begins with a fade-in of the initial conditions.\n",
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
"\n",
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Define the animation in Manim\n",
"class FieldTransition(Scene):\n",
" def construct(self):\n",
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
"\n",
"\n",
" # Place the images on top of each other initially\n",
" lpt_img.move_to(init_conditions_img)\n",
" for img in snapshots_imgs:\n",
" img.move_to(init_conditions_img)\n",
"\n",
" # Show initial field and then transform between fields\n",
" self.play(FadeIn(init_conditions_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(init_conditions_img, lpt_img))\n",
" self.wait(0.2)\n",
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
" self.wait(0.2)\n",
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
" self.play(Transform(img1, img2))\n",
" self.wait(0.2)\n",
"\n",
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long