fixing formatting

This commit is contained in:
Francois Lanusse 2025-06-28 23:21:31 +02:00
parent 1fcfc2aef2
commit 9d0b047f06
2 changed files with 56 additions and 51 deletions

View file

@ -3,7 +3,12 @@ from jaxpm.growth import growth_factor as Gp
from jaxpm.pm import pm_forces from jaxpm.pm import pm_forces
def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None): def symplectic_fpm_ode(mesh_shape,
dt0,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def drift(a, vel, args): def drift(a, vel, args):
""" """
state is a tuple (position, velocities) state is a tuple (position, velocities)
@ -39,17 +44,13 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
# Set the scale factors # Set the scale factors
ac = t1 ac = t1
forces = ( forces = (pm_forces(
pm_forces(
pos, pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding, sharding=sharding,
) ) * 1.5 * cosmo.Omega_m)
* 1.5
* cosmo.Omega_m
)
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
@ -73,17 +74,13 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
t1 = t0 + dt0 t1 = t0 + dt0
t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1 t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1
forces = ( forces = (pm_forces(
pm_forces(
pos, pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding, sharding=sharding,
) ) * 1.5 * cosmo.Omega_m)
* 1.5
* cosmo.Omega_m
)
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
@ -94,7 +91,12 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
return drift, kick, first_kick return drift, kick, first_kick
def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None):
def symplectic_ode(mesh_shape,
paint_absolute_pos=True,
halo_size=0,
sharding=None):
def drift(a, vel, args): def drift(a, vel, args):
""" """
state is a tuple (position, velocities) state is a tuple (position, velocities)
@ -113,17 +115,13 @@ def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=No
cosmo = args cosmo = args
forces = ( forces = (pm_forces(
pm_forces(
pos, pos,
mesh_shape=mesh_shape, mesh_shape=mesh_shape,
paint_absolute_pos=paint_absolute_pos, paint_absolute_pos=paint_absolute_pos,
halo_size=halo_size, halo_size=halo_size,
sharding=sharding, sharding=sharding,
) ) * 1.5 * cosmo.Omega_m)
* 1.5
* cosmo.Omega_m
)
# Computes the update of velocity (kick) # Computes the update of velocity (kick)
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces dvel = 1.0 / (a**2 * E(cosmo, a)) * forces

View file

@ -1,12 +1,16 @@
from functools import partial
import healpy as hp
import jax
import jax.numpy as jnp import jax.numpy as jnp
import jax_healpy as jhp import jax_healpy as jhp
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import jax
from functools import partial
import healpy as hp
@partial(jax.jit, static_argnames=('nside', 'fov', 'center_radec' , 'd_R' , 'box_size'))
def paint_spherical(volume, nside, fov, center_radec, observer_position, box_size, R, d_R): @partial(jax.jit,
static_argnames=('nside', 'fov', 'center_radec', 'd_R', 'box_size'))
def paint_spherical(volume, nside, fov, center_radec, observer_position,
box_size, R, d_R):
width, height, depth = volume.shape width, height, depth = volume.shape
ra0, dec0 = center_radec ra0, dec0 = center_radec
fov_width, fov_height = fov fov_width, fov_height = fov
@ -16,7 +20,9 @@ def paint_spherical(volume, nside, fov, center_radec, observer_position, box_siz
res_deg = jhp.nside2resol(nside, arcmin=True) / 60 res_deg = jhp.nside2resol(nside, arcmin=True) / 60
if pixel_scale_x > res_deg or pixel_scale_y > res_deg: if pixel_scale_x > res_deg or pixel_scale_y > res_deg:
print(f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside.") print(
f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside."
)
y_idx, x_idx = jnp.indices((height, width)) y_idx, x_idx = jnp.indices((height, width))
ra_grid = ra0 + x_idx * pixel_scale_x ra_grid = ra0 + x_idx * pixel_scale_x
@ -29,7 +35,8 @@ def paint_spherical(volume, nside, fov, center_radec, observer_position, box_siz
XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False) XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False)
observer_position = jnp.array(observer_position) observer_position = jnp.array(observer_position)
# Convert observer position from box units to grid units # Convert observer position from box units to grid units
observer_position = observer_position / jnp.array(box_size) * jnp.array(volume.shape) observer_position = observer_position / jnp.array(box_size) * jnp.array(
volume.shape)
coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :] coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :]