mirror of
https://github.com/DifferentiableUniverseInitiative/JaxPM.git
synced 2025-06-29 08:31:11 +00:00
fixing formatting
This commit is contained in:
parent
1fcfc2aef2
commit
9d0b047f06
2 changed files with 56 additions and 51 deletions
82
jaxpm/ode.py
82
jaxpm/ode.py
|
@ -3,7 +3,12 @@ from jaxpm.growth import growth_factor as Gp
|
||||||
from jaxpm.pm import pm_forces
|
from jaxpm.pm import pm_forces
|
||||||
|
|
||||||
|
|
||||||
def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sharding=None):
|
def symplectic_fpm_ode(mesh_shape,
|
||||||
|
dt0,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
def drift(a, vel, args):
|
def drift(a, vel, args):
|
||||||
"""
|
"""
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
|
@ -14,11 +19,11 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
|
||||||
t1 = a + dt0
|
t1 = a + dt0
|
||||||
# Set the scale factors
|
# Set the scale factors
|
||||||
ai = t0
|
ai = t0
|
||||||
ac = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
ac = (t0 * t1)**0.5 # Geometric mean of t0 and t1
|
||||||
af = t1
|
af = t1
|
||||||
|
|
||||||
#drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac)
|
#drift_contr = (Gp(cosmo, af) - Gp(cosmo, ai)) / gp(cosmo, ac)
|
||||||
drift_contr = (af - ai )/ ac
|
drift_contr = (af - ai) / ac
|
||||||
# Computes the update of position (drift)
|
# Computes the update of position (drift)
|
||||||
dpos = 1 / (ac**3 * E(cosmo, ac)) * vel
|
dpos = 1 / (ac**3 * E(cosmo, ac)) * vel
|
||||||
|
|
||||||
|
@ -34,27 +39,23 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
|
||||||
t0 = a
|
t0 = a
|
||||||
t1 = t0 + dt0
|
t1 = t0 + dt0
|
||||||
t2 = t1 + dt0
|
t2 = t1 + dt0
|
||||||
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1
|
||||||
t1t2 = (t1 * t2) ** 0.5 # Geometric mean of t1 and t2
|
t1t2 = (t1 * t2)**0.5 # Geometric mean of t1 and t2
|
||||||
# Set the scale factors
|
# Set the scale factors
|
||||||
ac = t1
|
ac = t1
|
||||||
|
|
||||||
forces = (
|
forces = (pm_forces(
|
||||||
pm_forces(
|
pos,
|
||||||
pos,
|
mesh_shape=mesh_shape,
|
||||||
mesh_shape=mesh_shape,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
paint_absolute_pos=paint_absolute_pos,
|
halo_size=halo_size,
|
||||||
halo_size=halo_size,
|
sharding=sharding,
|
||||||
sharding=sharding,
|
) * 1.5 * cosmo.Omega_m)
|
||||||
)
|
|
||||||
* 1.5
|
|
||||||
* cosmo.Omega_m
|
|
||||||
)
|
|
||||||
|
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
|
dvel = 1.0 / (ac**2 * E(cosmo, ac)) * forces
|
||||||
# First kick control factor
|
# First kick control factor
|
||||||
kick_factor_1 = (t1 - t0t1) / t1
|
kick_factor_1 = (t1 - t0t1) / t1
|
||||||
#kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1)
|
#kick_factor_1 = (Gf(cosmo, t1) - Gf(cosmo, t0t1)) / dGfa(cosmo, t1)
|
||||||
# Second kick control factor
|
# Second kick control factor
|
||||||
kick_factor_2 = (t2 - t1t2) / t2
|
kick_factor_2 = (t2 - t1t2) / t2
|
||||||
|
@ -71,19 +72,15 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
|
||||||
# Get the time steps
|
# Get the time steps
|
||||||
t0 = a
|
t0 = a
|
||||||
t1 = t0 + dt0
|
t1 = t0 + dt0
|
||||||
t0t1 = (t0 * t1) ** 0.5 # Geometric mean of t0 and t1
|
t0t1 = (t0 * t1)**0.5 # Geometric mean of t0 and t1
|
||||||
|
|
||||||
forces = (
|
forces = (pm_forces(
|
||||||
pm_forces(
|
pos,
|
||||||
pos,
|
mesh_shape=mesh_shape,
|
||||||
mesh_shape=mesh_shape,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
paint_absolute_pos=paint_absolute_pos,
|
halo_size=halo_size,
|
||||||
halo_size=halo_size,
|
sharding=sharding,
|
||||||
sharding=sharding,
|
) * 1.5 * cosmo.Omega_m)
|
||||||
)
|
|
||||||
* 1.5
|
|
||||||
* cosmo.Omega_m
|
|
||||||
)
|
|
||||||
|
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
||||||
|
@ -94,7 +91,12 @@ def symplectic_fpm_ode(mesh_shape, dt0, paint_absolute_pos=True, halo_size=0, sh
|
||||||
|
|
||||||
return drift, kick, first_kick
|
return drift, kick, first_kick
|
||||||
|
|
||||||
def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=None):
|
|
||||||
|
def symplectic_ode(mesh_shape,
|
||||||
|
paint_absolute_pos=True,
|
||||||
|
halo_size=0,
|
||||||
|
sharding=None):
|
||||||
|
|
||||||
def drift(a, vel, args):
|
def drift(a, vel, args):
|
||||||
"""
|
"""
|
||||||
state is a tuple (position, velocities)
|
state is a tuple (position, velocities)
|
||||||
|
@ -113,21 +115,17 @@ def symplectic_ode(mesh_shape, paint_absolute_pos=True, halo_size=0, sharding=No
|
||||||
|
|
||||||
cosmo = args
|
cosmo = args
|
||||||
|
|
||||||
forces = (
|
forces = (pm_forces(
|
||||||
pm_forces(
|
pos,
|
||||||
pos,
|
mesh_shape=mesh_shape,
|
||||||
mesh_shape=mesh_shape,
|
paint_absolute_pos=paint_absolute_pos,
|
||||||
paint_absolute_pos=paint_absolute_pos,
|
halo_size=halo_size,
|
||||||
halo_size=halo_size,
|
sharding=sharding,
|
||||||
sharding=sharding,
|
) * 1.5 * cosmo.Omega_m)
|
||||||
)
|
|
||||||
* 1.5
|
|
||||||
* cosmo.Omega_m
|
|
||||||
)
|
|
||||||
|
|
||||||
# Computes the update of velocity (kick)
|
# Computes the update of velocity (kick)
|
||||||
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
dvel = 1.0 / (a**2 * E(cosmo, a)) * forces
|
||||||
|
|
||||||
return dvel
|
return dvel
|
||||||
|
|
||||||
return drift, kick
|
return drift, kick
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import healpy as hp
|
||||||
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax_healpy as jhp
|
import jax_healpy as jhp
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import jax
|
|
||||||
from functools import partial
|
|
||||||
import healpy as hp
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnames=('nside', 'fov', 'center_radec' , 'd_R' , 'box_size'))
|
|
||||||
def paint_spherical(volume, nside, fov, center_radec, observer_position, box_size, R, d_R):
|
@partial(jax.jit,
|
||||||
|
static_argnames=('nside', 'fov', 'center_radec', 'd_R', 'box_size'))
|
||||||
|
def paint_spherical(volume, nside, fov, center_radec, observer_position,
|
||||||
|
box_size, R, d_R):
|
||||||
width, height, depth = volume.shape
|
width, height, depth = volume.shape
|
||||||
ra0, dec0 = center_radec
|
ra0, dec0 = center_radec
|
||||||
fov_width, fov_height = fov
|
fov_width, fov_height = fov
|
||||||
|
@ -16,7 +20,9 @@ def paint_spherical(volume, nside, fov, center_radec, observer_position, box_siz
|
||||||
|
|
||||||
res_deg = jhp.nside2resol(nside, arcmin=True) / 60
|
res_deg = jhp.nside2resol(nside, arcmin=True) / 60
|
||||||
if pixel_scale_x > res_deg or pixel_scale_y > res_deg:
|
if pixel_scale_x > res_deg or pixel_scale_y > res_deg:
|
||||||
print(f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside.")
|
print(
|
||||||
|
f"WARNING Pixel scale ({pixel_scale_x:.4f} deg, {pixel_scale_y:.4f} deg) is larger than the Healpy resolution ({res_deg:.4f} deg). Increase the field of view or decrease the nside."
|
||||||
|
)
|
||||||
|
|
||||||
y_idx, x_idx = jnp.indices((height, width))
|
y_idx, x_idx = jnp.indices((height, width))
|
||||||
ra_grid = ra0 + x_idx * pixel_scale_x
|
ra_grid = ra0 + x_idx * pixel_scale_x
|
||||||
|
@ -24,13 +30,14 @@ def paint_spherical(volume, nside, fov, center_radec, observer_position, box_siz
|
||||||
|
|
||||||
ra_flat = ra_grid.flatten() * jnp.pi / 180.0
|
ra_flat = ra_grid.flatten() * jnp.pi / 180.0
|
||||||
dec_flat = dec_grid.flatten() * jnp.pi / 180.0
|
dec_flat = dec_grid.flatten() * jnp.pi / 180.0
|
||||||
R_s = jnp.arange(0 , d_R, 1.0) + R
|
R_s = jnp.arange(0, d_R, 1.0) + R
|
||||||
|
|
||||||
XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False)
|
XYZ = R_s.reshape(-1, 1, 1) * jhp.ang2vec(ra_flat, dec_flat, lonlat=False)
|
||||||
observer_position = jnp.array(observer_position)
|
observer_position = jnp.array(observer_position)
|
||||||
# Convert observer position from box units to grid units
|
# Convert observer position from box units to grid units
|
||||||
observer_position = observer_position / jnp.array(box_size) * jnp.array(volume.shape)
|
observer_position = observer_position / jnp.array(box_size) * jnp.array(
|
||||||
|
volume.shape)
|
||||||
|
|
||||||
coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :]
|
coords = XYZ + jnp.asarray(observer_position)[jnp.newaxis, jnp.newaxis, :]
|
||||||
|
|
||||||
pixels = jhp.ang2pix(nside, ra_flat, dec_flat, lonlat=False)
|
pixels = jhp.ang2pix(nside, ra_flat, dec_flat, lonlat=False)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue