mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
Add Spherical lensing example
This commit is contained in:
parent
2d21985279
commit
f6d547e31f
5 changed files with 1048 additions and 381 deletions
221
jaxpm/lensing.py
221
jaxpm/lensing.py
|
@ -1,82 +1,181 @@
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_cosmo
|
import jax_cosmo
|
||||||
|
import jax_cosmo as jc
|
||||||
import jax_cosmo.constants as constants
|
import jax_cosmo.constants as constants
|
||||||
from jax.scipy.ndimage import map_coordinates
|
from jax.scipy.ndimage import map_coordinates
|
||||||
|
|
||||||
from jaxpm.painting import cic_paint_2d
|
from jaxpm.distributed import uniform_particles
|
||||||
|
from jaxpm.painting import cic_paint, cic_paint_2d, cic_paint_dx
|
||||||
|
from jaxpm.spherical import paint_spherical
|
||||||
from jaxpm.utils import gaussian_smoothing
|
from jaxpm.utils import gaussian_smoothing
|
||||||
|
|
||||||
|
|
||||||
def density_plane(positions,
|
def density_plane_fn(box_shape,
|
||||||
box_shape,
|
box_size,
|
||||||
center,
|
density_plane_width,
|
||||||
width,
|
density_plane_npix,
|
||||||
plane_resolution,
|
sharding=None):
|
||||||
smoothing_sigma=None):
|
|
||||||
""" Extacts a density plane from the simulation
|
def f(t, y, args):
|
||||||
|
positions = y[0]
|
||||||
|
cosmo = args
|
||||||
|
nx, ny, nz = box_shape
|
||||||
|
|
||||||
|
# Converts time t to comoving distance in voxel coordinates
|
||||||
|
w = density_plane_width / box_size[2] * box_shape[2]
|
||||||
|
center = jc.background.radial_comoving_distance(
|
||||||
|
cosmo, t) / box_size[2] * box_shape[2]
|
||||||
|
positions = uniform_particles(box_shape) + positions
|
||||||
|
xy = positions[..., :2]
|
||||||
|
d = positions[..., 2]
|
||||||
|
|
||||||
|
# Apply 2d periodic conditions
|
||||||
|
xy = jnp.mod(xy, nx)
|
||||||
|
|
||||||
|
# Rescaling positions to target grid
|
||||||
|
xy = xy / nx * density_plane_npix
|
||||||
|
# Selecting only particles that fall inside the volume of interest
|
||||||
|
weight = jnp.where((d > (center - w / 2)) & (d <= (center + w / 2)),
|
||||||
|
1.0, 0.0)
|
||||||
|
# Painting density plane
|
||||||
|
zero_mesh = jnp.zeros([density_plane_npix, density_plane_npix])
|
||||||
|
# Apply sharding in order to recover sharding when taking gradients
|
||||||
|
if sharding is not None:
|
||||||
|
xy = jax.lax.with_sharding_constraint(xy, sharding)
|
||||||
|
# Apply CIC painting
|
||||||
|
density_plane = cic_paint_2d(zero_mesh, xy, weight)
|
||||||
|
|
||||||
|
# Apply density normalization
|
||||||
|
density_plane = density_plane / ((nx / density_plane_npix) *
|
||||||
|
(ny / density_plane_npix) * w)
|
||||||
|
return density_plane
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def spherical_density_fn(box_shape,
|
||||||
|
box_size,
|
||||||
|
nside,
|
||||||
|
fov,
|
||||||
|
center_radec,
|
||||||
|
observer_position,
|
||||||
|
d_R,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
|
def f(t, y, args):
|
||||||
|
positions = y[0]
|
||||||
|
nx, ny, nz = box_shape
|
||||||
|
bx, by, bz = box_size
|
||||||
|
cosmo = args
|
||||||
|
# Converts time t to comoving distance in voxel coordinates
|
||||||
|
w = d_R / box_size[2] * box_shape[2]
|
||||||
|
center = ((jc.background.radial_comoving_distance(cosmo, t)) / bz) * nz
|
||||||
|
|
||||||
|
# Apply sharding in order to recover sharding when taking gradients
|
||||||
|
if sharding is not None:
|
||||||
|
positions = jax.lax.with_sharding_constraint(positions, sharding)
|
||||||
|
|
||||||
|
density_mesh = cic_paint_dx(positions)
|
||||||
|
# Project to spherical map
|
||||||
|
spherical_map = paint_spherical(density_mesh, nside, fov, center_radec,
|
||||||
|
observer_position, box_size, center,
|
||||||
|
d_R)
|
||||||
|
return spherical_map
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================================================
|
||||||
|
# Weak Lensing Born Approximation
|
||||||
|
# ==========================================================
|
||||||
|
def convergence_Born(cosmo, density_planes, r, a, dx, dz, coords, z_source):
|
||||||
"""
|
"""
|
||||||
nx, ny, nz = box_shape
|
Compute Born-approximation lensing convergence maps.
|
||||||
xy = positions[..., :2]
|
|
||||||
d = positions[..., 2]
|
|
||||||
|
|
||||||
# Apply 2d periodic conditions
|
Parameters
|
||||||
xy = jnp.mod(xy, nx)
|
----------
|
||||||
|
cosmo : jc.Cosmology
|
||||||
|
Cosmology object.
|
||||||
|
density_planes : ndarray
|
||||||
|
3D array of lensing density planes [nx, ny, n_planes].
|
||||||
|
r, a : ndarray
|
||||||
|
Comoving distances and scale factors per plane.
|
||||||
|
dx : float
|
||||||
|
Pixel scale.
|
||||||
|
dz : float
|
||||||
|
Redshift bin width.
|
||||||
|
coords : ndarray
|
||||||
|
Angular coordinates grid [2, N, 2] in radians.
|
||||||
|
z_source : ndarray
|
||||||
|
Source redshifts.
|
||||||
|
|
||||||
# Rescaling positions to target grid
|
Returns
|
||||||
xy = xy / nx * plane_resolution
|
-------
|
||||||
|
convergence : ndarray
|
||||||
# Selecting only particles that fall inside the volume of interest
|
2D convergence map for each source redshift.
|
||||||
weight = jnp.where(
|
|
||||||
(d > (center - width / 2)) & (d <= (center + width / 2)), 1., 0.)
|
|
||||||
# Painting density plane
|
|
||||||
density_plane = cic_paint_2d(
|
|
||||||
jnp.zeros([plane_resolution, plane_resolution]), xy, weight)
|
|
||||||
|
|
||||||
# Apply density normalization
|
|
||||||
density_plane = density_plane / ((nx / plane_resolution) *
|
|
||||||
(ny / plane_resolution) * (width))
|
|
||||||
|
|
||||||
# Apply Gaussian smoothing if requested
|
|
||||||
if smoothing_sigma is not None:
|
|
||||||
density_plane = gaussian_smoothing(density_plane, smoothing_sigma)
|
|
||||||
|
|
||||||
return density_plane
|
|
||||||
|
|
||||||
|
|
||||||
def convergence_Born(cosmo, density_planes, coords, z_source):
|
|
||||||
"""
|
"""
|
||||||
Compute the Born convergence
|
|
||||||
Args:
|
|
||||||
cosmo: `Cosmology`, cosmology object.
|
|
||||||
density_planes: list of dictionaries (r, a, density_plane, dx, dz), lens planes to use
|
|
||||||
coords: a 3-D array of angular coordinates in radians of N points with shape [batch, N, 2].
|
|
||||||
z_source: 1-D `Tensor` of source redshifts with shape [Nz] .
|
|
||||||
name: `string`, name of the operation.
|
|
||||||
Returns:
|
|
||||||
`Tensor` of shape [batch_size, N, Nz], of convergence values.
|
|
||||||
"""
|
|
||||||
# Compute constant prefactor:
|
|
||||||
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||||
# Compute comoving distance of source galaxies
|
# Compute comoving distance of source galaxies
|
||||||
r_s = jax_cosmo.background.radial_comoving_distance(
|
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||||
cosmo, 1 / (1 + z_source))
|
n_planes = len(r)
|
||||||
|
|
||||||
convergence = 0
|
def scan_fn(carry, i):
|
||||||
for entry in density_planes:
|
density_planes, a, r = carry
|
||||||
r = entry['r']
|
|
||||||
a = entry['a']
|
p = density_planes[:, :, i]
|
||||||
p = entry['plane']
|
density_normalization = dz * r[i] / a[i]
|
||||||
dx = entry['dx']
|
|
||||||
dz = entry['dz']
|
|
||||||
# Normalize density planes
|
|
||||||
density_normalization = dz * r / a
|
|
||||||
p = (p - p.mean()) * constant_factor * density_normalization
|
p = (p - p.mean()) * constant_factor * density_normalization
|
||||||
|
|
||||||
# Interpolate at the density plane coordinates
|
# Interpolate at the density plane coordinates
|
||||||
im = map_coordinates(p, coords * r / dx - 0.5, order=1, mode="wrap")
|
im = map_coordinates(p, coords * r[i] / dx - 0.5, order=1, mode="wrap")
|
||||||
|
|
||||||
convergence += im * jnp.clip(1. -
|
return carry, im * jnp.clip(1.0 -
|
||||||
(r / r_s), 0, 1000).reshape([-1, 1, 1])
|
(r[i] / r_s), 0, 1000).reshape([-1, 1, 1])
|
||||||
|
|
||||||
return convergence
|
_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r),
|
||||||
|
jnp.arange(n_planes))
|
||||||
|
return convergence.sum(axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def spherical_convergence_Born(cosmo, density_planes, r, a, nside, z_source):
|
||||||
|
"""
|
||||||
|
Compute Born-approximation lensing convergence maps on a sphere.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cosmo : jc.Cosmology
|
||||||
|
Cosmology object.
|
||||||
|
density_planes : ndarray
|
||||||
|
3D array of lensing density planes [n_planes, npix].
|
||||||
|
r, a : ndarray
|
||||||
|
Comoving distances and scale factors per plane.
|
||||||
|
nside : int
|
||||||
|
Healpix nside parameter.
|
||||||
|
z_source : ndarray
|
||||||
|
Source redshifts.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
convergence : ndarray
|
||||||
|
2D convergence map for each source redshift.
|
||||||
|
"""
|
||||||
|
constant_factor = 3 / 2 * cosmo.Omega_m * (constants.H0 / constants.c)**2
|
||||||
|
# Compute comoving distance of source galaxies
|
||||||
|
r_s = jc.background.radial_comoving_distance(cosmo, 1 / (1 + z_source))
|
||||||
|
n_planes = len(r)
|
||||||
|
|
||||||
|
def scan_fn(carry, i):
|
||||||
|
density_planes, a, r = carry
|
||||||
|
|
||||||
|
p = density_planes[i, :]
|
||||||
|
density_normalization = r[i] / a[
|
||||||
|
i] # This normalization needs to be checked
|
||||||
|
p = (p - p.mean()) * constant_factor * density_normalization
|
||||||
|
|
||||||
|
return carry, p * jnp.clip(1.0 -
|
||||||
|
(r[i] / r_s), 0, 1000).reshape([-1, 1])
|
||||||
|
|
||||||
|
_, convergence = jax.lax.scan(scan_fn, (density_planes, a, r),
|
||||||
|
jnp.arange(n_planes))
|
||||||
|
return convergence.sum(axis=0)
|
||||||
|
|
133
jaxpm/ode.py
Normal file
133
jaxpm/ode.py
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
from jaxpm.growth import E, Gf, dGfa, gp
|
||||||
|
from jaxpm.growth import growth_factor as Gp
|
||||||
|
from jaxpm.pm import pm_forces
|
||||||
|
|
||||||
|
|
||||||
|
def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None):
|
||||||
|
def drift(a, vel, args):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
cosmo = args[0]
|
||||||
|
# Get the time steps
|
||||||
|
t0 = a
|
||||||
|
t1 = a + dt0
|
||||||
|
# Set the scale factors
|
||||||
|
ai = t0
|
||||||
|
ac = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
||||||
|
af = t1
|
||||||
|
|
||||||
|
#drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac)
|
||||||
|
drift_contr = (af - ai )/ ac
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = 1 / (ac**3 * E(cosmo, ac)) * vel
|
||||||
|
|
||||||
|
return dpos * (drift_contr / dt0)
|
||||||
|
|
||||||
|
def kick(a, pos, args):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
cosmo = args
|
||||||
|
# Get the time steps
|
||||||
|
t0 = a
|
||||||
|
t1 = t0 + dt0
|
||||||
|
t2 = t1 + dt0
|
||||||
|
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
||||||
|
t1t2 = (t1 * t2) ** 0.5 # Geometric mean of t1 and t2
|
||||||
|
# Set the scale factors
|
||||||
|
ac = t1
|
||||||
|
|
||||||
|
forces = (
|
||||||
|
pm_forces(
|
||||||
|
pos,
|
||||||
|
mesh_shape=mesh_shape,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding,
|
||||||
|
)
|
||||||
|
* 1.5
|
||||||
|
* cosmo.Omega_m
|
||||||
|
)
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
|
||||||
|
# First kick control factor
|
||||||
|
kick_factor_1 = (t1 - t0t1) / t1
|
||||||
|
#kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1)
|
||||||
|
# Second kick control factor
|
||||||
|
kick_factor_2 = (t2 - t1t2) / t2
|
||||||
|
#kick_factor_2 = (Gf(cosmo, t1t2) - Gf(cosmo, t1)) / dGfa(cosmo, t1)
|
||||||
|
|
||||||
|
return dvel * ((kick_factor_1 + kick_factor_2) / dt0)
|
||||||
|
|
||||||
|
def first_kick(a, pos, args):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
cosmo = args
|
||||||
|
# Get the time steps
|
||||||
|
t0 = a
|
||||||
|
t1 = t0 + dt0
|
||||||
|
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
||||||
|
|
||||||
|
forces = (
|
||||||
|
pm_forces(
|
||||||
|
pos,
|
||||||
|
mesh_shape=mesh_shape,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding,
|
||||||
|
)
|
||||||
|
* 1.5
|
||||||
|
* cosmo.Omega_m
|
||||||
|
)
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
||||||
|
# First kick control factor
|
||||||
|
kick_factor = (Gf(cosmo, t0t1) - Gf(cosmo, t0)) / dGfa(cosmo, t0)
|
||||||
|
|
||||||
|
return dvel * (kick_factor / dt0)
|
||||||
|
|
||||||
|
return drift, kick, first_kick
|
||||||
|
|
||||||
|
def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None):
|
||||||
|
def drift(a, vel, args):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
cosmo = args
|
||||||
|
# Computes the update of position (drift)
|
||||||
|
dpos = 1 / (a**3 * E(cosmo, a)) * vel
|
||||||
|
|
||||||
|
return dpos
|
||||||
|
|
||||||
|
def kick(a, pos, args):
|
||||||
|
"""
|
||||||
|
state is a tuple (position, velocities)
|
||||||
|
"""
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
|
||||||
|
cosmo = args
|
||||||
|
|
||||||
|
forces = (
|
||||||
|
pm_forces(
|
||||||
|
pos,
|
||||||
|
mesh_shape=mesh_shape,
|
||||||
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
|
halo_size=halo_size,
|
||||||
|
sharding=sharding,
|
||||||
|
)
|
||||||
|
* 1.5
|
||||||
|
* cosmo.Omega_m
|
||||||
|
)
|
||||||
|
|
||||||
|
# Computes the update of velocity (kick)
|
||||||
|
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
||||||
|
|
||||||
|
return dvel
|
||||||
|
|
||||||
|
return drift, kick
|
50
jaxpm/spherical.py
Normal file
50
jaxpm/spherical.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import jax_healpy as jhp
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import jax
|
||||||
|
from functools import partial
|
||||||
|
import healpy as hp
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnames=('nside', 'fov', 'center_radec' , 'd_R' , 'box_size'))
|
||||||
|
def paint_spherical(volume, nside, fov, center_radec, observer_position, box_size, R, d_R):
|
||||||
|
width, height, depth = volume.shape
|
||||||
|
ra0, dec0 = center_radec
|
||||||
|
fov_width, fov_height = fov
|
||||||
|
|
||||||
|
pixel_scale_x = fov_width / width
|
||||||
|
pixel_scale_y = fov_height / height
|
||||||
|
|
||||||
|
res_deg = jhp.nside2resol(nside, arcmin=True) / 60
|
||||||
|
if pixel_scale_x > res_deg or pixel_scale_y > res_deg:
|
||||||
|
print(f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside.")
|
||||||
|
|
||||||
|
y_idx, x_idx = jnp.indices((height, width))
|
||||||
|
ra_grid = ra0 + x_idx * pixel_scale_x
|
||||||
|
dec_grid = dec0 + y_idx * pixel_scale_y
|
||||||
|
|
||||||
|
ra_flat = ra_grid.flatten() * jnp.pi / 180.0
|
||||||
|
dec_flat = dec_grid.flatten() * jnp.pi / 180.0
|
||||||
|
R_s = jnp.arange(0 , d_R, 1.0) + R
|
||||||
|
|
||||||
|
XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False)
|
||||||
|
observer_position = jnp.array(observer_position)
|
||||||
|
# Convert observer position from box units to grid units
|
||||||
|
observer_position = observer_position / jnp.array(box_size) * jnp.array(volume.shape)
|
||||||
|
|
||||||
|
coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :]
|
||||||
|
|
||||||
|
pixels = jhp.ang2pix(nside, ra_flat, dec_flat, lonlat=False)
|
||||||
|
|
||||||
|
npix = jhp.nside2npix(nside)
|
||||||
|
|
||||||
|
@partial(jax.vmap, in_axes=(0, None, None))
|
||||||
|
def interpolate_volume(coords, volume, pixels):
|
||||||
|
voxels = jax.scipy.ndimage.map_coordinates(volume, coords.T, order=1)
|
||||||
|
sums = jnp.bincount(pixels, weights=voxels, length=npix)
|
||||||
|
return sums
|
||||||
|
|
||||||
|
sum_map = interpolate_volume(coords, volume, pixels).sum(axis=0)
|
||||||
|
counts = jnp.bincount(pixels, length=npix)
|
||||||
|
sum_map = jnp.where(counts > 0, sum_map / counts, jhp.UNSEEN)
|
||||||
|
|
||||||
|
return sum_map
|
|
@ -1,320 +0,0 @@
|
||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# **Animating Particle Mesh density fields**\n",
|
|
||||||
"\n",
|
|
||||||
"In this tutorial, we will animate the density field of a particle mesh simulation. We will use the `manim` library to create the animation. \n",
|
|
||||||
"\n",
|
|
||||||
"The density fields are created exactly like in the notebook [**05-MultiHost_PM.ipynb**](05-MultiHost_PM.ipynb) using the same script [**05-MultiHost_PM.py**](05-MultiHost_PM.py)."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"To run a multi-host simulation, you first need to **allocate a job** with `salloc`. This command requests resources on an HPC cluster.\n",
|
|
||||||
"\n",
|
|
||||||
"just like in notebook [**05-MultiHost_PM.ipynb**]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!salloc --account=XXX@a100 -C a100 --gres=gpu:8 --ntasks-per-node=8 --time=00:40:00 --cpus-per-task=8 --hint=nomultithread --qos=qos_gpu-dev --nodes=4 & "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"**A few hours later**\n",
|
|
||||||
"\n",
|
|
||||||
"Use `!squeue -u $USER -o \"%i %D %b\"` to **check the JOB ID** and verify your resource allocation.\n",
|
|
||||||
"\n",
|
|
||||||
"In this example, we’ve been allocated **32 GPUs split across 4 nodes**.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!squeue -u $USER -o \"%i %D %b\""
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"Unset the following environment variables, as they can cause issues when using JAX in a distributed setting:"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"del os.environ['VSCODE_PROXY_URI']\n",
|
|
||||||
"del os.environ['NO_PROXY']\n",
|
|
||||||
"del os.environ['no_proxy']"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Checking Available Compute Resources\n",
|
|
||||||
"\n",
|
|
||||||
"Run the following command to initialize JAX distributed computing and display the devices available for this job:\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!srun --jobid=467745 -n 32 python -c \"import jax; jax.distributed.initialize(); print(jax.devices()) if jax.process_index() == 0 else None\""
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Multi-Host Simulation Script with Arguments (reminder)\n",
|
|
||||||
"\n",
|
|
||||||
"This script is nearly identical to the single-host version, with the main addition being the call to `jax.distributed.initialize()` at the start, enabling multi-host parallelism. Here’s a breakdown of the key arguments:\n",
|
|
||||||
"\n",
|
|
||||||
"- **`--pdims`** (`-p`): Specifies processor grid dimensions as two integers, like `16 2` for 16 x 2 device mesh (default is `[1, jax.devices()]`).\n",
|
|
||||||
"- **`--mesh_shape`** (`-m`): Defines the simulation mesh shape as three integers (default is `[512, 512, 512]`).\n",
|
|
||||||
"- **`--box_size`** (`-b`): Sets the physical box size of the simulation as three floating-point values, e.g., `1000. 1000. 1000.` (default is `[500.0, 500.0, 500.0]`).\n",
|
|
||||||
"- **`--halo_size`** (`-H`): Specifies the halo size for boundary overlap across nodes (default is `64`).\n",
|
|
||||||
"- **`--solver`** (`-s`): Chooses the ODE solver (`leapfrog` or `dopri8`). The `leapfrog` solver uses a fixed step size, while `dopri8` is an adaptive Runge-Kutta solver with a PID controller (default is `leapfrog`).\n",
|
|
||||||
"- **`--snapthots`** (`-st`) : Number of snapshots to save (warning, increases memory usage)\n",
|
|
||||||
"\n",
|
|
||||||
"### Running the Multi-Host Simulation Script\n",
|
|
||||||
"\n",
|
|
||||||
"To create a smooth animation, we need a series of closely spaced snapshots to capture the evolution of the density field over time. In this example, we set the number of snapshots to **10** to ensure smooth transitions in the animation.\n",
|
|
||||||
"\n",
|
|
||||||
"Using a larger number of GPUs helps process these snapshots efficiently, especially with a large simulation mesh or high-resolution data. This allows us to achieve both the desired snapshot frequency and the necessary simulation detail without excessive runtime.\n",
|
|
||||||
"\n",
|
|
||||||
"The command to run the multi-host simulation with these settings will look something like this:\n",
|
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import subprocess\n",
|
|
||||||
"\n",
|
|
||||||
"# Define parameters as variables\n",
|
|
||||||
"jobid = \"467745\"\n",
|
|
||||||
"num_processes = 32\n",
|
|
||||||
"script_name = \"05-MultiHost_PM.py\"\n",
|
|
||||||
"mesh_shape = (1024, 1024, 1024)\n",
|
|
||||||
"box_size = (1000., 1000., 1000.)\n",
|
|
||||||
"halo_size = 128\n",
|
|
||||||
"solver = \"leapfrog\"\n",
|
|
||||||
"pdims = (16, 2)\n",
|
|
||||||
"snapshots = 8\n",
|
|
||||||
"\n",
|
|
||||||
"# Build the command as a list, incorporating variables\n",
|
|
||||||
"command = [\n",
|
|
||||||
" \"srun\",\n",
|
|
||||||
" f\"--jobid={jobid}\",\n",
|
|
||||||
" \"-n\", str(num_processes),\n",
|
|
||||||
" \"python\", script_name,\n",
|
|
||||||
" \"--mesh_shape\", str(mesh_shape[0]), str(mesh_shape[1]), str(mesh_shape[2]),\n",
|
|
||||||
" \"--box_size\", str(box_size[0]), str(box_size[1]), str(box_size[2]),\n",
|
|
||||||
" \"--halo_size\", str(halo_size),\n",
|
|
||||||
" \"-s\", solver,\n",
|
|
||||||
" \"--pdims\", str(pdims[0]), str(pdims[1]),\n",
|
|
||||||
" \"--snapshots\", str(snapshots)\n",
|
|
||||||
"]\n",
|
|
||||||
"\n",
|
|
||||||
"# Execute the command as a subprocess\n",
|
|
||||||
"subprocess.run(command)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Projecting the 3D Density Fields to 2D\n",
|
|
||||||
"\n",
|
|
||||||
"To visualize the 3D density fields in 2D, we need to create a projection:\n",
|
|
||||||
"\n",
|
|
||||||
"- **`project_to_2d` Function**: This function reduces the 3D array to 2D by summing over a portion of one axis.\n",
|
|
||||||
" - We sum the top one-eighth of the data along the first axis to capture a slice of the density field.\n",
|
|
||||||
"\n",
|
|
||||||
"- **Creating 2D Projections**: Apply `project_to_2d` to each 3D field (`initial_conditions`, `lpt_displacements`, `ode_solution_0`, and `ode_solution_1`) to get 2D arrays that represent the density fields.\n",
|
|
||||||
"\n",
|
|
||||||
"### Applying the Magma Colormap\n",
|
|
||||||
"\n",
|
|
||||||
"To improve visualization, apply the \"magma\" colormap to each 2D projection:\n",
|
|
||||||
"\n",
|
|
||||||
"- **`apply_colormap` Function**: This function maps values in the 2D array to colors using the \"magma\" colormap.\n",
|
|
||||||
" - First, normalize the array to the `[0, 1]` range.\n",
|
|
||||||
" - Apply the colormap to create RGB images, which will be used for the animation.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from matplotlib import colormaps\n",
|
|
||||||
"\n",
|
|
||||||
"# Define a function to project the 3D field to 2D\n",
|
|
||||||
"def project_to_2d(field):\n",
|
|
||||||
" sum_over = field.shape[0] // 8\n",
|
|
||||||
" slicing = [slice(None)] * field.ndim\n",
|
|
||||||
" slicing[0] = slice(None, sum_over)\n",
|
|
||||||
" slicing = tuple(slicing)\n",
|
|
||||||
"\n",
|
|
||||||
" return field[slicing].sum(axis=0)\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def apply_colormap(array, cmap_name=\"magma\"):\n",
|
|
||||||
" cmap = colormaps[cmap_name]\n",
|
|
||||||
" normalized_array = (array - array.min()) / (array.max() - array.min())\n",
|
|
||||||
" colored_image = cmap(normalized_array)[:, :, :3] # Drop alpha channel for RGB\n",
|
|
||||||
" return (colored_image * 255).astype(np.uint8)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Loading and Visualizing Results\n",
|
|
||||||
"\n",
|
|
||||||
"After running the multi-host simulation, we load the saved results from disk:\n",
|
|
||||||
"\n",
|
|
||||||
"- **`initial_conditions.npy`**: Initial conditions for the simulation.\n",
|
|
||||||
"- **`lpt_displacements.npy`**: Linear perturbation displacements.\n",
|
|
||||||
"- **`ode_solution_*.npy`** : Solutions from the ODE solver at each snapshot.\n",
|
|
||||||
"\n",
|
|
||||||
"We will now project the fields to 2D maps and apply the color map\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import numpy as np\n",
|
|
||||||
"\n",
|
|
||||||
"initial_conditions = apply_colormap(project_to_2d(np.load('fields/initial_conditions.npy')))\n",
|
|
||||||
"lpt_displacements = apply_colormap(project_to_2d(np.load('fields/lpt_displacements.npy')))\n",
|
|
||||||
"ode_solutions = []\n",
|
|
||||||
"for i in range(8):\n",
|
|
||||||
" ode_solutions.append(apply_colormap(project_to_2d(np.load(f'fields/ode_solution_{i}.npy'))))"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Animating with Manim\n",
|
|
||||||
"\n",
|
|
||||||
"To create animations with `manim` in a Jupyter notebook, we start by configuring some settings to ensure the output displays correctly and without a background.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from manim import *\n",
|
|
||||||
"config.media_width = \"100%\"\n",
|
|
||||||
"config.verbosity = \"WARNING\"\n",
|
|
||||||
"config.background_color = \"#00000000\" # Transparent background"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"### Defining the Animation in Manim\n",
|
|
||||||
"\n",
|
|
||||||
"This animation class, `FieldTransition`, smoothly transitions through the stages of the particle mesh density field evolution.\n",
|
|
||||||
"\n",
|
|
||||||
"- **Setup**: Each density field snapshot is loaded as an image and aligned for smooth transitions.\n",
|
|
||||||
"- **Animation Sequence**:\n",
|
|
||||||
" - The animation begins with a fade-in of the initial conditions.\n",
|
|
||||||
" - It then transitions through the stages in sequence, showing each snapshot of the density field evolution with brief pauses in between.\n",
|
|
||||||
"\n",
|
|
||||||
"To run the animation, execute `%manim -v WARNING -qm FieldTransition` to render it in the Jupyter Notebook.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Define the animation in Manim\n",
|
|
||||||
"class FieldTransition(Scene):\n",
|
|
||||||
" def construct(self):\n",
|
|
||||||
" init_conditions_img = ImageMobject(initial_conditions).scale(4)\n",
|
|
||||||
" lpt_img = ImageMobject(lpt_displacements).scale(4)\n",
|
|
||||||
" snapshots_imgs = [ImageMobject(sol).scale(4) for sol in ode_solutions]\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
" # Place the images on top of each other initially\n",
|
|
||||||
" lpt_img.move_to(init_conditions_img)\n",
|
|
||||||
" for img in snapshots_imgs:\n",
|
|
||||||
" img.move_to(init_conditions_img)\n",
|
|
||||||
"\n",
|
|
||||||
" # Show initial field and then transform between fields\n",
|
|
||||||
" self.play(FadeIn(init_conditions_img))\n",
|
|
||||||
" self.wait(0.2)\n",
|
|
||||||
" self.play(Transform(init_conditions_img, lpt_img))\n",
|
|
||||||
" self.wait(0.2)\n",
|
|
||||||
" self.play(Transform(lpt_img, snapshots_imgs[0]))\n",
|
|
||||||
" self.wait(0.2)\n",
|
|
||||||
" for img1, img2 in zip(snapshots_imgs, snapshots_imgs[1:]):\n",
|
|
||||||
" self.play(Transform(img1, img2))\n",
|
|
||||||
" self.wait(0.2)\n",
|
|
||||||
"\n",
|
|
||||||
"%manim -v WARNING -qm -o anim.gif --format=gif FieldTransition "
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.10.4"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
705
notebooks/06-RayTracing.ipynb
Normal file
705
notebooks/06-RayTracing.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Add table
Add a link
Reference in a new issue