borg_velocity/borg_velocity/projection.py
2024-04-24 00:04:01 +02:00

96 lines
No EOL
3.4 KiB
Python

import jax.numpy as jnp
import jax.scipy.ndimage
import jax
from functools import partial
from astropy.coordinates import SkyCoord
import astropy.units as apu
@partial(jax.jit, static_argnames=('order',))
def jit_map_coordinates(image: jnp.ndarray, coords: jnp.ndarray, order: int) -> jnp.ndarray:
    """
    Compiled wrapper around jax.scipy.ndimage.map_coordinates with periodic boundaries.

    Args:
        - image (jnp.ndarray): array to sample from
        - coords (jnp.ndarray): fractional index coordinates at which image is sampled;
          the first axis runs over the dimensions of image
        - order (int): interpolation order (static under jit). JAX implements only
          0 (nearest neighbour) and 1 (linear).
    Returns:
        - (jnp.ndarray): image sampled at coords, with shape coords.shape[1:]
    """
    # jax's 'wrap' mode is the same as scipy's 'grid-wrap': the grid is treated as periodic
    sampled = jax.scipy.ndimage.map_coordinates(
        image,
        coords,
        order=order,
        mode='wrap',
    )
    return sampled
@partial(jax.jit, static_argnames=['order', 'use_jitted'])
def interp_field(input_array: jnp.ndarray, coords: jnp.ndarray, L: float, origin: jnp.ndarray, order: int, use_jitted: bool = False) -> jnp.ndarray:
    """
    Interpolate a field (or a stack of fields) on a periodic cubic grid at arbitrary positions.

    The grid has N points per side (N = input_array.shape[-1]) spanning a box of side L,
    with grid index [0, 0, ...] located at `origin`. A position p therefore maps to the
    fractional index (p - origin) * N / L. Positions outside the box wrap around periodically.

    Args:
        - input_array (jnp.ndarray): field to interpolate, with ndim spatial axes, optionally
          preceded by one extra leading axis of independent fields (each is interpolated
          separately via jax.vmap)
        - coords (jnp.ndarray, shape=(ndim, ...)): positions to evaluate at; the first axis
          holds the ndim components, and any number of trailing axes index the points
          (e.g. (ndim, npoint0, npoint1))
        - L (float): side length of the box
        - origin (jnp.ndarray, shape=(ndim,)): position corresponding to index [0,0,...]
        - order (int): interpolation order; JAX supports 0 (nearest) and 1 (linear)
        - use_jitted (bool): if True, interpolate through the separately jitted
          jit_map_coordinates instead of calling jax.scipy.ndimage.map_coordinates directly
    Returns:
        - out_array (jnp.ndarray): field evaluated at coords, with shape coords.shape[1:],
          or (nfield,) + coords.shape[1:] when input_array has the extra leading axis
    Raises:
        - NotImplementedError: if input_array has neither ndim nor ndim + 1 dimensions
    """
    ndim = coords.shape[0]
    # Assumes a cubic grid: every spatial axis has the same length N as the last one
    N = input_array.shape[-1]

    # Reshape origin to (ndim, 1, ..., 1) so it broadcasts over any number of point axes
    origin_b = jnp.reshape(origin, (-1,) + (1,) * (coords.ndim - 1))
    # Change coords to (fractional) grid indices
    pos = (coords - origin_b) * N / L

    def interpolate_one(arr):
        # NOTE: jax's 'wrap' is the same as scipy's 'grid-wrap'
        if use_jitted:
            return jit_map_coordinates(arr, pos, order)
        return jax.scipy.ndimage.map_coordinates(arr, pos, order=order, mode='wrap')

    if input_array.ndim == ndim:
        return interpolate_one(input_array)
    if input_array.ndim == ndim + 1:
        # Interpolate each field along the leading axis independently
        return jax.vmap(interpolate_one)(input_array)
    raise NotImplementedError(
        f"input_array must have {ndim} or {ndim + 1} dimensions for "
        f"{ndim}-dimensional coords, got shape {input_array.shape}"
    )
@jax.jit
def project_radial(vec: jnp.ndarray, coords: jnp.ndarray, origin: jnp.ndarray) -> jnp.ndarray:
    """
    Project vectors onto the radial direction from an observer to each point.

    Args:
        - vec (jnp.ndarray, shape=(ndim, ...)): vectors to be projected, one per point,
          laid out like coords
        - coords (jnp.ndarray, shape=(ndim, ...)): positions of the points; the first axis
          holds the ndim components, and any number of trailing axes index the points
          (e.g. (ndim, npoint0, npoint1))
        - origin (jnp.ndarray, shape=(ndim,)): coordinates of the observer
    Returns:
        - vr (jnp.ndarray, shape=coords.shape[1:]): vec projected along the unit vector
          (coords - origin) / |coords - origin|. Points coinciding with the origin give
          NaN (0/0), since the radial direction is undefined there.
    """
    # Reshape origin to (ndim, 1, ..., 1) so it broadcasts over any number of point axes
    origin_b = jnp.reshape(origin, (-1,) + (1,) * (coords.ndim - 1))
    x = coords - origin_b
    r = jnp.sqrt(jnp.sum(x**2, axis=0))
    r_hat = x / r[jnp.newaxis]
    vr = jnp.sum(r_hat * vec, axis=0)
    return vr
def get_radial_vectors(coord_meas):
    """
    Compute unit line-of-sight vectors pointing from the origin to each measured position.

    The direction is computed directly as x / |x| in float64. This is equivalent to
    converting to (RA, Dec) and back, without the cost and the precision lost by a
    round trip through angles in degrees.

    Args:
        - coord_meas: Cartesian positions; coord_meas[0], coord_meas[1] and coord_meas[2]
          are the x, y and z components (array-likes of a common shape)
    Returns:
        - r_hat (jnp.ndarray): the unit vectors with the whole array transposed (so the
          component axis is last) and a length-1 axis prepended; for components of shape
          (npoint,) this is shape (1, npoint, 3). A zero vector is assigned the direction
          (1, 0, 0), i.e. RA = Dec = 0.
    """
    import numpy as np

    xyz = np.stack([np.asarray(coord_meas[i], dtype=np.float64) for i in range(3)])
    r = np.sqrt(np.sum(xyz**2, axis=0))
    is_zero = r == 0
    # Avoid 0/0 for zero vectors, then point them along +x (RA = Dec = 0)
    r_hat = xyz / np.where(is_zero, 1.0, r)
    r_hat[0] = np.where(is_zero, 1.0, r_hat[0])
    r_hat = jnp.array(r_hat.T)
    r_hat = jnp.expand_dims(r_hat, axis=0)
    return r_hat