From 9d0b047f06a3b289e20fc46cdae8caa2ed72669f Mon Sep 17 00:00:00 2001 From: Francois Lanusse Date: Sat, 28 Jun 2025 23:21:31 +0200 Subject: [PATCH] fixing formatting --- jaxpm/ode.py | 82 ++++++++++++++++++++++------------------------ jaxpm/spherical.py | 25 +++++++++----- 2 files changed, 56 insertions(+), 51 deletions(-) diff --git a/jaxpm/ode.py b/jaxpm/ode.py index 9a093c2..f552e8e 100644 --- a/jaxpm/ode.py +++ b/jaxpm/ode.py @@ -3,7 +3,12 @@ from jaxpm.growth import growth_factor as Gp from jaxpm.pm import pm_forces -def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None): +def symplectic_fpm_ode(mesh_shape, + dt0, + paint_absolute_pos=True, + halo_size=0, + sharding=None): + def drift(a, vel, args): """ state is a tuple (position, velocities) @@ -14,11 +19,11 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh t1 = a + dt0 # Set the scale factors ai = t0 - ac = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1 + ac = (t0 * t1)**0.5 # Geometric mean of t0 and t1 af = t1 #drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac) - drift_contr = (af - ai )/ ac + drift_contr = (af - ai) / ac # Computes the update of position (drift) dpos = 1 / (ac**3 * E(cosmo, ac)) * vel @@ -34,27 +39,23 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh t0 = a t1 = t0 + dt0 t2 = t1 + dt0 - t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1 - t1t2 = (t1 * t2) ** 0.5 # Geometric mean of t1 and t2 + t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1 + t1t2 = (t1 * t2)**0.5 # Geometric mean of t1 and t2 # Set the scale factors ac = t1 - forces = ( - pm_forces( - pos, - mesh_shape=mesh_shape, - paint_absolute_pos=paint_absolute_pos, - halo_size=halo_size, - sharding=sharding, - ) - * 1.5 - * cosmo.Omega_m - ) + forces = (pm_forces( + pos, + mesh_shape=mesh_shape, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding, + ) * 1.5 * cosmo.Omega_m) # Computes the update of velocity (kick) dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces # First kick control factor - kick_factor_1 = (t1 - t0t1) / t1 + kick_factor_1 = (t1 - t0t1) / t1 #kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1) # Second kick control factor kick_factor_2 = (t2 - t1t2) / t2 @@ -71,19 +72,15 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh # Get the time steps t0 = a t1 = t0 + dt0 - t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1 + t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1 - forces = ( - pm_forces( - pos, - mesh_shape=mesh_shape, - paint_absolute_pos=paint_absolute_pos, - halo_size=halo_size, - sharding=sharding, - ) - * 1.5 - * cosmo.Omega_m - ) + forces = (pm_forces( + pos, + mesh_shape=mesh_shape, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding, + ) * 1.5 * cosmo.Omega_m) # Computes the update of velocity (kick) dvel = 1.0 / (a**2 * E(cosmo, a)) * forces @@ -94,7 +91,12 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh return drift, kick, first_kick -def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None): + +def symplectic_ode(mesh_shape, + paint_absolute_pos=True, + halo_size=0, + sharding=None): + def drift(a, vel, args): """ state is a tuple (position, velocities) @@ -113,21 +115,17 @@ def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=No cosmo = args - forces = ( - pm_forces( - pos, - mesh_shape=mesh_shape, - paint_absolute_pos=paint_absolute_pos, - halo_size=halo_size, - sharding=sharding, - ) - * 1.5 - * cosmo.Omega_m - ) + forces = (pm_forces( + pos, + mesh_shape=mesh_shape, + paint_absolute_pos=paint_absolute_pos, + halo_size=halo_size, + sharding=sharding, + ) * 1.5 * cosmo.Omega_m) # Computes the update of velocity (kick) dvel = 1.0 / (a**2 * E(cosmo, a)) * forces return dvel - return drift, kick \ No newline at end of file + return drift, kick diff --git a/jaxpm/spherical.py b/jaxpm/spherical.py index bdace1a..ab3c558 100644 --- a/jaxpm/spherical.py +++ b/jaxpm/spherical.py @@ -1,12 +1,16 @@ +from functools import partial + +import healpy as hp +import jax import jax.numpy as jnp import jax_healpy as jhp import matplotlib.pyplot as plt -import jax -from functools import partial -import healpy as hp -@partial(jax.jit, static_argnames=('nside', 'fov', 'center_radec' , 'd_R' , 'box_size')) -def paint_spherical(volume, nside, fov, center_radec, observer_position, box_size, R, d_R): + +@partial(jax.jit, + static_argnames=('nside', 'fov', 'center_radec', 'd_R', 'box_size')) +def paint_spherical(volume, nside, fov, center_radec, observer_position, + box_size, R, d_R): width, height, depth = volume.shape ra0, dec0 = center_radec fov_width, fov_height = fov @@ -16,7 +20,9 @@ def paint_spherical(volume, nside, fov, center_radec, observer_position, box_siz res_deg = jhp.nside2resol(nside, arcmin=True) / 60 if pixel_scale_x > res_deg or pixel_scale_y > res_deg: - print(f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside.") + print( + f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside." + ) y_idx, x_idx = jnp.indices((height, width)) ra_grid = ra0 + x_idx * pixel_scale_x @@ -24,13 +30,14 @@ def paint_spherical(volume, nside, fov, center_radec, observer_position, box_siz ra_flat = ra_grid.flatten() * jnp.pi / 180.0 dec_flat = dec_grid.flatten() * jnp.pi / 180.0 - R_s = jnp.arange(0 , d_R, 1.0) + R + R_s = jnp.arange(0, d_R, 1.0) + R XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False) observer_position = jnp.array(observer_position) # Convert observer position from box units to grid units - observer_position = observer_position / jnp.array(box_size) * jnp.array(volume.shape) - + observer_position = observer_position / jnp.array(box_size) * jnp.array( + volume.shape) + coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :] pixels = jhp.ang2pix(nside, ra_flat, dec_flat, lonlat=False)