96 lines
No EOL
3.4 KiB
Python
96 lines
No EOL
3.4 KiB
Python
import jax.numpy as jnp
|
|
import jax.scipy.ndimage
|
|
import jax
|
|
from functools import partial
|
|
from astropy.coordinates import SkyCoord
|
|
import astropy.units as apu
|
|
|
|
@partial(jax.jit, static_argnames=('order',))
def jit_map_coordinates(image: jnp.ndarray, coords: jnp.ndarray, order: int) -> jnp.ndarray:
    """
    Jit-compiled wrapper around ``jax.scipy.ndimage.map_coordinates`` with periodic boundaries.

    Args:
        - image (jnp.ndarray): the array to sample
        - coords (jnp.ndarray): sample positions in (fractional) index units; the first
          axis runs over the dimensions of ``image``
        - order (int): interpolation order, static under jit. JAX currently implements
          only 0 (nearest neighbour) and 1 (linear).

    Returns:
        - jnp.ndarray: ``image`` sampled at ``coords``; the output shape is the shape of
          ``coords`` with its first axis dropped.
    """
    # JAX's 'wrap' boundary mode is what scipy calls 'grid-wrap' (fully periodic)
    boundary = 'wrap'
    sampler = jax.scipy.ndimage.map_coordinates
    return sampler(image, coords, order, mode=boundary)
|
|
|
|
|
|
@partial(jax.jit, static_argnames=['order', 'use_jitted'])
def interp_field(input_array: jnp.ndarray, coords: jnp.ndarray, L: float, origin: jnp.ndarray, order: int, use_jitted: bool = False) -> jnp.ndarray:
    """
    Interpolate a field sampled on a periodic ND-cubic grid at arbitrary coordinates.

    Args:
        - input_array (jnp.ndarray): array to be interpolated. Either a single field with
          one axis per spatial dimension (ndim axes), or a stack of such fields with an
          extra leading axis (ndim + 1 axes); each field in a stack is interpolated
          independently.
        - coords (jnp.ndarray, shape=(ndim, ...)): coordinates to evaluate at. The first
          axis indexes the spatial dimension; the remaining axes are arbitrary
          (e.g. (ndim, npoint0, npoint1)).
        - L (float): side length of the (cubic) box
        - origin (jnp.ndarray, shape=(ndim,)): position corresponding to index [0,0,...]
        - order (int): order of interpolation (JAX implements 0 = nearest, 1 = linear)
        - use_jitted (bool): if True, evaluate through ``jit_map_coordinates``; otherwise
          call ``jax.scipy.ndimage.map_coordinates`` directly. Results are identical.

    Returns:
        - out_array (jnp.ndarray): field evaluated at coords. Shape is coords.shape[1:]
          for a single field, or (nfield,) + coords.shape[1:] for a stack.

    Raises:
        - NotImplementedError: if input_array has neither ndim nor ndim + 1 axes.
    """
    ndim = coords.shape[0]
    # The grid is assumed cubic, so the number of cells per side is read off the last axis
    N = input_array.shape[-1]

    # Broadcast origin against coords of any trailing rank (not only (ndim, n0, n1)),
    # then convert physical positions into fractional grid indices
    origin = jnp.reshape(jnp.asarray(origin), (ndim,) + (1,) * (coords.ndim - 1))
    pos = (coords - origin) * N / L

    # NOTE: jax's 'wrap' is the same as scipy's 'grid-wrap' (periodic box)
    def interp_one(arr):
        if use_jitted:
            return jit_map_coordinates(arr, pos, order)
        return jax.scipy.ndimage.map_coordinates(arr, pos, order=order, mode='wrap')

    if input_array.ndim == ndim:
        return interp_one(input_array)
    if input_array.ndim == ndim + 1:
        # Leading axis of input_array enumerates independent fields
        return jax.vmap(interp_one, in_axes=0)(input_array)
    raise NotImplementedError(
        f"input_array must have {ndim} or {ndim + 1} axes to match coords of shape "
        f"{coords.shape}, got shape {input_array.shape}"
    )
|
|
|
|
|
|
@jax.jit
def project_radial(vec: jnp.ndarray, coords: jnp.ndarray, origin: jnp.ndarray) -> jnp.ndarray:
    """
    Project vectors onto the radial direction pointing from an observer to each position.

    Args:
        - vec (jnp.ndarray, shape=(ndim, ...)): vectors to be projected; the first axis
          holds the vector components, e.g. shape (ndim, npoint0, npoint1)
        - coords (jnp.ndarray, shape=(ndim, ...)): positions at which the vectors live;
          must broadcast against vec
        - origin (jnp.ndarray, shape=(ndim,)): coordinates of the observer

    Returns:
        - vr (jnp.ndarray): component of vec along the unit vector from origin to coords,
          with the leading (component) axis summed out, e.g. shape (npoint0, npoint1).
          Positions coinciding with origin have no defined direction and yield 0
          instead of NaN.
    """
    ndim = coords.shape[0]
    # Broadcast origin over any trailing rank of coords, not only (ndim, n0, n1)
    x = coords - jnp.reshape(jnp.asarray(origin), (ndim,) + (1,) * (coords.ndim - 1))

    r2 = jnp.sum(x**2, axis=0)
    nonzero = r2 > 0
    # Double-where: substitute a harmless value before sqrt/divide so that neither the
    # result nor its gradient becomes NaN where the position equals the observer
    r = jnp.sqrt(jnp.where(nonzero, r2, 1.0))
    vr = jnp.sum(x * vec, axis=0) / r
    return jnp.where(nonzero, vr, 0.0)
|
|
|
|
|
|
def get_radial_vectors(coord_meas):
    """
    Return unit vectors along the line of sight to each measured position.

    The Cartesian positions are converted to sky angles (longitude/latitude in degrees)
    with astropy, and those angles are turned back into Cartesian vectors. A coordinate
    rebuilt from angles alone carries no distance, so its Cartesian form is a unit vector.

    Args:
        - coord_meas: indexable so that [0], [1], [2] give the x, y, z Cartesian
          components (presumably arrays of shape (npoint,) -- TODO confirm with callers)

    Returns:
        - r_hat (jnp.ndarray): the unit vectors, transposed and given a new leading axis
          of length 1; for 1-D components of length npoint the shape is (1, npoint, 3)
    """
    sky = SkyCoord(x=coord_meas[0], y=coord_meas[1], z=coord_meas[2],
                   representation_type='cartesian')
    sph = sky.spherical
    ra_deg, dec_deg = sph.lon.degree, sph.lat.degree

    unit_xyz = SkyCoord(ra=ra_deg * apu.deg, dec=dec_deg * apu.deg).cartesian.xyz
    # .T moves the component axis last; [None, ...] prepends a length-1 axis
    return jnp.array(unit_xyz).T[None, ...]