This commit is contained in:
Wassim Kabalan 2025-07-04 19:32:42 +02:00
parent c485a3aed0
commit 60f16a9180

View file

@ -1,48 +1,51 @@
from functools import partial
import healpy as hp
import jax
import jax.numpy as jnp
import jax_healpy as jhp
import numpy as np
import healpy as hp
def gaussian_kernel_angle(gamma, sigma):
"""Gaussian kernel for angular separations.
Parameters
----------
gamma : float or array
Angular separation in radians
sigma : float
Kernel width parameter
Returns
-------
kernel_value : float or array
Gaussian kernel value
"""
return np.exp(-(gamma**2)/(2*sigma**2))/(2*np.pi*sigma**2)
return np.exp(-(gamma**2) / (2 * sigma**2)) / (2 * np.pi * sigma**2)
def sigma_from_h(h, R_center):
"""Convert smoothing length h to angular sigma.
Parameters
----------
h : float
Smoothing length
R_center : float
Center distance of the shell
Returns
-------
sigma : float
Angular sigma parameter
"""
return h/R_center
return h / R_center
def pixels_within_radius(nside, vec, gamma_cut):
"""Find HEALPix pixels within angular radius.
Parameters
----------
nside : int
@ -51,7 +54,7 @@ def pixels_within_radius(nside, vec, gamma_cut):
Unit 3-vector of the particle
gamma_cut : float
Angular radius cutoff in radians
Returns
-------
pixel_indices : array
@ -59,14 +62,15 @@ def pixels_within_radius(nside, vec, gamma_cut):
"""
return hp.query_disc(nside, vec, gamma_cut, inclusive=True, nest=False)
def angsep_vecs(vec1, vec2):
"""Compute angular separation between unit vectors.
Parameters
----------
vec1, vec2 : array, shape (3,)
Unit 3-vectors
Returns
-------
gamma : float
@ -74,14 +78,15 @@ def angsep_vecs(vec1, vec2):
"""
return np.arccos(np.clip(np.dot(vec1, vec2), -1.0, 1.0))
def _healpix_cic_weights(nside, theta, phi):
"""
Compute CIC (Cloud-in-Cell) weights for HEALPix pixels.
This implements a simplified differentiable CIC scheme for spherical coordinates.
Instead of using exact HEALPix face projections, we use a smooth approximation
based on angular distances to nearby pixels.
Parameters
----------
nside : int
@ -90,7 +95,7 @@ def _healpix_cic_weights(nside, theta, phi):
Colatitude in radians (0 to pi)
phi : array_like
Longitude in radians (0 to 2pi)
Returns
-------
pixel_indices : array, shape (4, N)
@ -100,64 +105,64 @@ def _healpix_cic_weights(nside, theta, phi):
"""
# Convert to HEALPix pixel coordinates
host_pixels = jhp.ang2pix(nside, theta, phi)
# Get pixel centers for the host pixel
theta_host, phi_host = jhp.pix2ang(nside, host_pixels)
# Estimate pixel size in radians
pixel_size = jnp.sqrt(4 * jnp.pi / (12 * nside**2))
# Compute offsets from pixel center
dtheta = theta - theta_host
dphi = phi - phi_host
# Handle phi wraparound
dphi = jnp.where(dphi > jnp.pi, dphi - 2*jnp.pi, dphi)
dphi = jnp.where(dphi < -jnp.pi, dphi + 2*jnp.pi, dphi)
dphi = jnp.where(dphi > jnp.pi, dphi - 2 * jnp.pi, dphi)
dphi = jnp.where(dphi < -jnp.pi, dphi + 2 * jnp.pi, dphi)
# Normalize to [0, 1] fractional coordinates within pixel
# Use a smooth approximation for differentiability
delta_u = jnp.clip(0.5 + dtheta / pixel_size, 0.0, 1.0 - 1e-7)
delta_v = jnp.clip(0.5 + dphi / pixel_size, 0.0, 1.0 - 1e-7)
# Create 4 neighbor pixels by using small angular offsets
# This is a simplified approach - we create neighbors by shifting theta/phi
offset = pixel_size * 0.25 # Quarter pixel offset
# Define the 4 corners relative to host pixel
theta_neighbors = jnp.array([
theta_host - offset, # SW
theta_host - offset, # SE
theta_host + offset, # NW
theta_host + offset # NE
theta_host + offset # NE
])
phi_neighbors = jnp.array([
phi_host - offset, # SW
phi_host + offset, # SE
phi_host - offset, # NW
phi_host + offset # NE
phi_host + offset # NE
])
# Convert neighbor coordinates to pixel indices
# Clamp theta to valid range
theta_neighbors = jnp.clip(theta_neighbors, 1e-8, jnp.pi - 1e-8)
# Get pixel indices for each neighbor (broadcasting over particles)
pixel_indices = jnp.array([
jhp.ang2pix(nside, theta_neighbors[i], phi_neighbors[i])
jhp.ang2pix(nside, theta_neighbors[i], phi_neighbors[i])
for i in range(4)
])
# Compute CIC weights for the 4 corners
# Order: SW, SE, NW, NE
weights = jnp.array([
(1 - delta_u) * (1 - delta_v), # SW
(1 - delta_u) * delta_v, # SE
delta_u * (1 - delta_v), # NW
delta_u * delta_v # NE
(1 - delta_u) * delta_v, # SE
delta_u * (1 - delta_v), # NW
delta_u * delta_v # NE
])
return pixel_indices, weights
@ -171,10 +176,10 @@ def paint_particles_spherical_cic(positions,
weights=None):
"""
Paint particles onto HEALPix spherical maps using differentiable CIC scheme.
This implements the Cloud-in-Cell (CIC) mass-assignment scheme for spherical
coordinates, making it differentiable unlike the NGP scheme.
Parameters
----------
positions : ndarray, shape (..., 3)
@ -191,7 +196,7 @@ def paint_particles_spherical_cic(positions,
Shape of the simulation mesh (nx, ny, nz)
weights : ndarray, optional
Particle weights (default: uniform weights)
Returns
-------
healpix_map : ndarray
@ -199,73 +204,75 @@ def paint_particles_spherical_cic(positions,
"""
if weights is None:
weights = jnp.ones(positions.shape[:-1])
# Convert particle positions to physical coordinates
positions = positions * jnp.array(box_size) / jnp.array(mesh_shape)
# Compute relative positions from observer
rel_positions = positions - jnp.asarray(observer_position)
# Convert to spherical coordinates
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[..., 2]
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[...,
2]
# Comoving distance from observer
r = jnp.sqrt(x**2 + y**2 + z**2)
# Apply distance cuts
distance_mask = (r >= R_min) & (r <= R_max)
# Compute angular coordinates
theta = jnp.arccos(jnp.clip(z / (r + 1e-10), -1, 1))
phi = jnp.arctan2(y, x)
# Apply distance mask to weights
masked_weights = (weights * distance_mask).flatten()
# Get CIC weights for each particle
pixel_indices, cic_weights = _healpix_cic_weights(nside, theta.flatten(), phi.flatten())
pixel_indices, cic_weights = _healpix_cic_weights(nside, theta.flatten(),
phi.flatten())
# Initialize HEALPix map
npix = jhp.nside2npix(nside)
healpix_map = jnp.zeros(npix)
# Apply CIC weights to each of the 4 neighbors
for i in range(4):
# Get pixel indices and weights for this neighbor
pix_idx = pixel_indices[i] # Shape: (N,)
cic_weight = cic_weights[i] # Shape: (N,)
# Combine distance mask and CIC weights
combined_weights = masked_weights * cic_weight
# Add contributions using scatter_add for differentiability
healpix_map = healpix_map.at[pix_idx].add(combined_weights)
# Calculate volume per pixel in spherical shell
pixel_solid_angle = 4 * jnp.pi / npix # steradians per pixel
R_center = 0.5 * (R_min + R_max)
shell_thickness = R_max - R_min
shell_volume_per_pixel = pixel_solid_angle * R_center**2 * shell_thickness
# Convert particle counts to density (particles per unit volume)
healpix_map = healpix_map / shell_volume_per_pixel
return healpix_map
def paint_particles_spherical_ngp(positions,
nside,
observer_position,
R_min,
R_max,
box_size,
mesh_shape,
weights=None):
nside,
observer_position,
R_min,
R_max,
box_size,
mesh_shape,
weights=None):
"""
Paint particles onto HEALPix spherical maps using Nearest Grid Point (NGP) scheme.
This is a non-differentiable method that assigns particles to the nearest pixel.
Parameters
----------
positions : ndarray, shape (..., 3)
@ -282,7 +289,7 @@ def paint_particles_spherical_ngp(positions,
Shape of the simulation mesh (nx, ny, nz)
weights : ndarray, optional
Particle weights (default: uniform weights)
Returns
-------
healpix_map : ndarray
@ -300,7 +307,8 @@ def paint_particles_spherical_ngp(positions,
rel_positions = positions - jnp.asarray(observer_position)
# Convert to spherical coordinates
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[..., 2]
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[...,
2]
# Comoving distance from observer
r = jnp.sqrt(x**2 + y**2 + z**2)
@ -322,7 +330,7 @@ def paint_particles_spherical_ngp(positions,
# Bin particles into HEALPix pixels
npix = jhp.nside2npix(nside)
healpix_map = jnp.bincount(pixels, weights=masked_weights, length=npix)
# Calculate volume per pixel in spherical shell
pixel_solid_angle = 4 * jnp.pi / npix # steradians per pixel
R_center = 0.5 * (R_min + R_max)
@ -347,10 +355,10 @@ def paint_particles_spherical_rbf(positions,
gamma_cut_factor=3.0):
"""
Paint particles onto HEALPix spherical maps using Radial Basis Function (RBF) scheme.
This implements a smooth RBF kernel for particle painting, using numpy/healpy
for non-differentiable but potentially more accurate calculations.
Parameters
----------
positions : ndarray, shape (..., 3)
@ -371,7 +379,7 @@ def paint_particles_spherical_rbf(positions,
Fixed smoothing width parameter. If None, uses adaptive smoothing.
gamma_cut_factor : float, optional
Cutoff factor for kernel support (default: 3.0)
Returns
-------
healpix_map : ndarray
@ -379,88 +387,89 @@ def paint_particles_spherical_rbf(positions,
"""
if weights is None:
weights = np.ones(positions.shape[:-1])
# Convert from JAX arrays to numpy if needed
positions = np.asarray(positions)
weights = np.asarray(weights)
observer_position = np.asarray(observer_position)
# Convert particle positions to physical coordinates
positions = positions * np.array(box_size) / np.array(mesh_shape)
# Compute relative positions from observer
rel_positions = positions - observer_position
# Convert to spherical coordinates
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[..., 2]
x, y, z = rel_positions[..., 0], rel_positions[..., 1], rel_positions[...,
2]
# Comoving distance from observer
r = np.sqrt(x**2 + y**2 + z**2)
# Apply distance cuts
distance_mask = (r >= R_min) & (r <= R_max)
# Initialize HEALPix map
npix = hp.nside2npix(nside)
healpix_map = np.zeros(npix)
# Flatten arrays for easier processing
positions_flat = positions.reshape(-1, 3)
rel_positions_flat = rel_positions.reshape(-1, 3)
r_flat = r.flatten()
weights_flat = weights.flatten()
distance_mask_flat = distance_mask.flatten()
# Calculate shell center distance for sigma calculation
R_center = 0.5 * (R_min + R_max)
# Process each particle
for i in range(len(positions_flat)):
if not distance_mask_flat[i]:
continue
# Get particle properties
r_p = r_flat[i]
m_p = weights_flat[i]
# Convert to unit vector
vec_p = rel_positions_flat[i] / r_p
# Determine smoothing width
if sigma_fixed is not None:
sigma_p = sigma_fixed
else:
# For now, use fixed sigma - adaptive can be added later
sigma_p = sigma_fixed if sigma_fixed is not None else 0.1 # Default value
# Cutoff radius for kernel support
gamma_cut = gamma_cut_factor * sigma_p
# Find pixels within cutoff radius
pix_indices = pixels_within_radius(nside, vec_p, gamma_cut)
# Calculate pixel area (constant for all pixels)
pixel_area = 4 * np.pi / npix
# Process each pixel within the cutoff
for pix in pix_indices:
# Get pixel center direction
vec_i = np.array(hp.pix2vec(nside, pix))
# Calculate angular separation
gamma = angsep_vecs(vec_p, vec_i)
# Calculate kernel weight
w = gaussian_kernel_angle(gamma, sigma_p) * pixel_area
# Add contribution to the map
healpix_map[pix] += m_p * w
# Apply shell-volume normalization
pixel_area = 4 * np.pi / npix
shell_vol_per_pix = pixel_area * (R_max**3 - R_min**3) / 3
healpix_map /= shell_vol_per_pix
return healpix_map
@ -496,7 +505,7 @@ def paint_particles_spherical(positions,
weights : ndarray, optional
Particle weights (default: uniform weights)
method : str, optional
Painting method: 'cic' for Cloud-in-Cell (differentiable), 'ngp' for
Painting method: 'cic' for Cloud-in-Cell (differentiable), 'ngp' for
Nearest Grid Point (non-differentiable), or 'rbf' for Radial Basis Function.
Default is 'cic'.
sigma_fixed : float, optional
@ -508,25 +517,27 @@ def paint_particles_spherical(positions,
-------
healpix_map : ndarray
HEALPix density map
"""
"""
# Check particle density warning
total_particles = jnp.prod(jnp.array(positions.shape[:-1])) # Total number of particles
total_particles = jnp.prod(jnp.array(
positions.shape[:-1])) # Total number of particles
npix = jhp.nside2npix(nside) # Total HEALPix pixels
# Estimate fraction of particles in the shell
# Approximate shell volume vs box volume
box_diagonal = jnp.sqrt(jnp.sum(jnp.array(box_size)**2))
max_distance = box_diagonal / 2 # Maximum possible distance from center
shell_volume_fraction = ((R_max**3 - R_min**3) / 3) / (max_distance**3 / 3)
shell_volume_fraction = jnp.minimum(shell_volume_fraction, 1.0) # Cap at 1.0
shell_volume_fraction = jnp.minimum(shell_volume_fraction,
1.0) # Cap at 1.0
estimated_particles_in_shell = total_particles * shell_volume_fraction
particles_per_pixel = estimated_particles_in_shell / npix
# Warn if particles per pixel is too low (threshold: 1 particle per pixel)
min_particles_per_pixel = 1.0
jax.lax.cond(particles_per_pixel < min_particles_per_pixel,
lambda: jax.debug.print(
jax.lax.cond(
particles_per_pixel < min_particles_per_pixel, lambda: jax.debug.print(
"WARNING: Low particle density detected! "
"Estimated {particles_per_pixel} particles per pixel (threshold: {min_threshold}). "
"This may result in shot noise and low statistical power. "
@ -537,28 +548,27 @@ def paint_particles_spherical(positions,
nside=nside,
total_particles=total_particles,
shell_fraction=shell_volume_fraction,
npix=npix
),
lambda: None
)
npix=npix), lambda: None)
# Choose method
if method.upper() == 'CIC':
# Apply JIT compilation for CIC method
jit_func = jax.jit(paint_particles_spherical_cic, static_argnames=('nside', 'mesh_shape', 'box_size'))
return jit_func(
positions, nside, observer_position, R_min, R_max,
box_size, mesh_shape, weights)
jit_func = jax.jit(paint_particles_spherical_cic,
static_argnames=('nside', 'mesh_shape', 'box_size'))
return jit_func(positions, nside, observer_position, R_min, R_max,
box_size, mesh_shape, weights)
elif method.upper() == 'NGP':
# Apply JIT compilation for NGP method
jit_func = jax.jit(paint_particles_spherical_ngp, static_argnames=('nside', 'mesh_shape', 'box_size'))
return jit_func(
positions, nside, observer_position, R_min, R_max,
box_size, mesh_shape, weights)
# Apply JIT compilation for NGP method
jit_func = jax.jit(paint_particles_spherical_ngp,
static_argnames=('nside', 'mesh_shape', 'box_size'))
return jit_func(positions, nside, observer_position, R_min, R_max,
box_size, mesh_shape, weights)
elif method.upper() == 'RBF':
# RBF method uses numpy/healpy - no JIT compilation
return paint_particles_spherical_rbf(
positions, nside, observer_position, R_min, R_max,
box_size, mesh_shape, weights, sigma_fixed, gamma_cut_factor)
return paint_particles_spherical_rbf(positions, nside,
observer_position, R_min, R_max,
box_size, mesh_shape, weights,
sigma_fixed, gamma_cut_factor)
else:
raise ValueError(f"Unknown method '{method}'. Choose 'cic', 'ngp', or 'rbf'.")
raise ValueError(
f"Unknown method '{method}'. Choose 'cic', 'ngp', or 'rbf'.")