dens2vel, projection and cosmological parameter loading
This commit is contained in:
parent
9828bd6efd
commit
874d26e38c
3 changed files with 249 additions and 0 deletions
101
src/forwards.py
Normal file
101
src/forwards.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
import jax
|
||||
import jax.numpy as jnp
|
||||
from functools import partial
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnums=(2,3))
|
||||
def dens2vel_linear(delta, f, Lbox, smooth_R):
|
||||
"""
|
||||
Converts the overdensity field, delta, into
|
||||
the velocity field, v, in units of [???], assuming
|
||||
the linear relation between the two.
|
||||
|
||||
Args:
|
||||
:delta (jnp.ndarray): (N,N,N) array of overdensity field
|
||||
:f (float): The growth rate
|
||||
:Lbox (float): The box length in Mpc/h
|
||||
:smooth_R (float): The smoothing scale in Mpc/h
|
||||
|
||||
Returns:
|
||||
:v (jnp.ndarray): The velocity field defined on the grid
|
||||
|
||||
TO CHECK: UNITS
|
||||
QUESTION: Why two different filters?
|
||||
"""
|
||||
|
||||
N = delta.shape[0]
|
||||
V = Lbox**3
|
||||
dV = (Lbox / N) ** 3
|
||||
|
||||
# Forward fft
|
||||
delta_k = dV / V * jnp.fft.rfftn(delta)
|
||||
delta_k = jnp.array([delta_k.real, delta_k.imag])
|
||||
|
||||
# Fourier mask for Nyquist planes
|
||||
N_Z = N//2 + 1
|
||||
Fourier_mask = jnp.ones((N,N,N_Z))
|
||||
Fourier_mask = Fourier_mask.at[:,N_Z,0].set(0.)
|
||||
Fourier_mask = Fourier_mask.at[:,N_Z,-1].set(0.)
|
||||
|
||||
# Indices for symmetrising
|
||||
update_index_real_0 = jnp.index_exp[0,N_Z:,:,0]
|
||||
update_index_real_ny = jnp.index_exp[0,N_Z:,:,N_Z-1]
|
||||
update_index_imag_0 = jnp.index_exp[1,N_Z:,:,0]
|
||||
update_index_imag_ny = jnp.index_exp[1,N_Z:,:,N_Z-1]
|
||||
flip_indices = -jnp.arange(N)
|
||||
flip_indices = flip_indices.at[N_Z-1].set(-flip_indices[N_Z-1])
|
||||
flip_indices = jnp.array(flip_indices.tolist())
|
||||
|
||||
# Symmetrise
|
||||
delta_k = Fourier_mask[jnp.newaxis] * delta_k
|
||||
delta_k = delta_k.at[update_index_real_0].set(
|
||||
jnp.take(jnp.flip(delta_k[0, 1:(N_Z-1), :, 0], axis=0),
|
||||
flip_indices, axis=1)
|
||||
)
|
||||
delta_k = delta_k.at[update_index_real_ny].set(
|
||||
jnp.take(jnp.flip(delta_k[0,1:(N_Z-1),:,N_Z-1], axis=0),
|
||||
flip_indices, axis=1)
|
||||
)
|
||||
delta_k = delta_k.at[update_index_imag_0].set(
|
||||
-jnp.take(jnp.flip(delta_k[1,1:(N_Z-1),:,0], axis=0),
|
||||
flip_indices, axis=1)
|
||||
)
|
||||
|
||||
delta_k = delta_k[0] + jnp.array(complex(0, 1)) * delta_k[1]
|
||||
|
||||
# Get k grid
|
||||
k = 2*jnp.pi*jnp.fft.fftfreq(N, d=Lbox/N)
|
||||
k_norm = jnp.sqrt(k[:,None,None]**2 + k[None,:,None]**2 + kz_vec[None,None,:]**2)
|
||||
k_norm = k_norm.at[(k_norm < 1e-10)].set(1e-15)
|
||||
|
||||
# Filter
|
||||
k_filter = jnp.exp(-0.5 * (k_norm[:,:,:N_Z] * smooth_R) ** 2)
|
||||
smooth_filter = k_filter
|
||||
vx = (
|
||||
smooth_filter * k_filter
|
||||
* jnp.array(complex(0, 1)) * 100 * f
|
||||
* delta_k_complex
|
||||
* jnp.tile(k[:,None,None],(1,N,N_Z))
|
||||
/ k_norm**2
|
||||
)
|
||||
vx = (jnp.fft.irfftn(vx)*V/dV)
|
||||
|
||||
vy = (
|
||||
smooth_filter * k_filter
|
||||
* jnp.array(complex(0, 1)) * 100 * f
|
||||
* delta_k_complex
|
||||
* jnp.tile(ky[None,:,None], (N,1,N_Z))
|
||||
/ k_norm**2
|
||||
)
|
||||
vy = (jnp.fft.irfftn(vy)*V/dV)
|
||||
|
||||
vz = (
|
||||
smooth_filter * k_filter
|
||||
* jnp.array(complex(0, 1)) * 100 * f
|
||||
* delta_k_complex
|
||||
* jnp.tile(kz[None,None,:N_Z], (N,N,1))
|
||||
/ k_norm**2
|
||||
)
|
||||
vz = (jnp.fft.irfftn(vz)*V/dV)
|
||||
|
||||
return jnp.array([vx, vy, vz])
|
80
src/projection.py
Normal file
80
src/projection.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import jax.numpy as jnp
|
||||
import jax.scipy.ndimage
|
||||
import jax
|
||||
from functools import partial
|
||||
|
||||
@partial(jax.jit, static_argnames=['order'])
|
||||
def jit_map_coordinates(image: jnp.ndarray, coords:jnp.ndarray, order: int) -> jnp.ndarray:
|
||||
"""
|
||||
Jitted version of jax.scipy.ndimage.map_coordinates
|
||||
|
||||
Args:
|
||||
- image (jnp.ndarray) - The input array
|
||||
- coords (jnp.ndarray) - The coordinates at which image is evaluated.
|
||||
- order (int): order of interpolation (0 <= order <= 5)
|
||||
Returns:
|
||||
- map_coordinates (jnp.ndarray) - The result of transforming the input. The shape of the output is derived from that of coordinates by dropping the first axis.
|
||||
|
||||
"""
|
||||
return jax.scipy.ndimage.map_coordinates(image, coords, order=order, mode='wrap')
|
||||
|
||||
|
||||
@partial(jax.jit, static_argnames=['order','use_jitted'])
|
||||
def interp_field(input_array: jnp.ndarray, coords: jnp.ndarray, L: float, origin: jnp.ndarray, order: int, use_jitted: bool=False) -> jnp.ndarray:
|
||||
"""
|
||||
Interpolate an array on a ND-cubic grid to new coordinates linearly
|
||||
|
||||
Args:
|
||||
- input_array (jnp.ndarray): array to be interpolated
|
||||
- coords (jnp.ndarray shape=(ndim,npoint0,npoint1)): coordinates to evaluate at
|
||||
- L (float): length of box
|
||||
- origin (jnp.ndarray, shape=(ndim,)): position corresponding to index [0,0,...]
|
||||
- order (int): order of interpolation (0 <= order <= 5)
|
||||
|
||||
Returns:
|
||||
- out_array (np.ndarray, SHAPE): field evaluated at coords of interest
|
||||
|
||||
"""
|
||||
|
||||
N = input_array.shape[-1]
|
||||
|
||||
# Change coords to index
|
||||
pos = (coords - jnp.expand_dims(origin, axis=(1,2))) * N / L
|
||||
|
||||
# NOTE: jax's 'wrap' is the same as scipy's 'grid-wrap'
|
||||
if use_jitted:
|
||||
def fun_to_vmap(arr):
|
||||
return jit_map_coordinates(arr, pos, order)
|
||||
else:
|
||||
def fun_to_vmap(arr):
|
||||
return jax.scipy.ndimage.map_coordinates(arr, pos, order=order, mode='wrap')
|
||||
|
||||
if len(input_array.shape) == coords.shape[0]:
|
||||
out_array = jit_map_coordinates(input_array, pos, order)
|
||||
elif len(input_array.shape) == coords.shape[0]+1:
|
||||
vmap_fun = jax.vmap(fun_to_vmap, in_axes=[0])
|
||||
out_array = vmap_fun(input_array)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return out_array
|
||||
|
||||
|
||||
@jax.jit
|
||||
def project_radial(vec: jnp.ndarray, coords: jnp.ndarray, origin: jnp.ndarray) -> jnp.ndarray:
|
||||
"""
|
||||
Project vectors along the radial direction, given by coords
|
||||
Args:
|
||||
- vec (jnp.ndarray, shape=(ndim,npoint0,npoint1)): array of vectors to be projected
|
||||
- coords (jnp.ndarray shape=(ndim,npoint0,npoint1)): coordinates to evaluate at
|
||||
- origin (jnp.ndarray, shape=(ndim,)): coordinates of the observer
|
||||
|
||||
Returns:
|
||||
- vr (jnp.ndarray, shape=(npoint0,npoint1)): vec projected along radial direction
|
||||
"""
|
||||
x = coords - jnp.expand_dims(origin, axis=(1,2))
|
||||
r = jnp.sqrt(jnp.sum(x**2, axis=0))
|
||||
x = x / jnp.expand_dims(r, axis=0)
|
||||
vr = jnp.sum(x * vec, axis=0)
|
||||
|
||||
return vr
|
68
src/utils.py
Normal file
68
src/utils.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import aquila_borg as borg
|
||||
import configparser
|
||||
import os
|
||||
|
||||
# Output stream management
|
||||
cons = borg.console()
|
||||
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))
|
||||
|
||||
|
||||
def get_cosmopar(ini_file):
|
||||
"""
|
||||
Extract cosmological parameters from an ini file
|
||||
|
||||
Args:
|
||||
:ini_file (str): Path to the ini file
|
||||
|
||||
Returns:
|
||||
:cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters
|
||||
"""
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(ini_file)
|
||||
|
||||
cpar = borg.cosmo.CosmologicalParameters()
|
||||
cpar.default()
|
||||
cpar.fnl = float(config['cosmology']['fnl'])
|
||||
cpar.omega_k = float(config['cosmology']['omega_k'])
|
||||
cpar.omega_m = float(config['cosmology']['omega_m'])
|
||||
cpar.omega_b = float(config['cosmology']['omega_b'])
|
||||
cpar.omega_q = float(config['cosmology']['omega_q'])
|
||||
cpar.h = float(config['cosmology']['h100'])
|
||||
cpar.sigma8 = float(config['cosmology']['sigma8'])
|
||||
cpar.n_s = float(config['cosmology']['n_s'])
|
||||
cpar.w = float(config['cosmology']['w'])
|
||||
cpar.wprime = float(config['cosmology']['wprime'])
|
||||
|
||||
return cpar
|
||||
|
||||
|
||||
def get_action():
|
||||
"""
|
||||
Find mode which BORG is currently being run in
|
||||
|
||||
Returns:
|
||||
:last_line (str): The mode in which BORG is being run
|
||||
"""
|
||||
|
||||
# Find the last line of the file
|
||||
with open('trace_file.txt', 'rb') as file:
|
||||
file.seek(-2, os.SEEK_END) # Go to the second last byte in the file.
|
||||
last_line = ''
|
||||
while len(last_line) == 0:
|
||||
while file.read(1) != b'\n': # Keep moving back until you find a newline.
|
||||
file.seek(-2, os.SEEK_CUR)
|
||||
last_line = file.readline().decode().strip('\n')
|
||||
|
||||
# Check that the this is the line was want and find action
|
||||
cmd = "hades_python"
|
||||
assert cmd in last_line
|
||||
idx = last_line.index(cmd)
|
||||
last_line = last_line[idx+len(cmd):]
|
||||
while last_line[0] == ' ':
|
||||
last_line = last_line[1:]
|
||||
idx = last_line.index(' ')
|
||||
last_line = last_line[:idx].upper()
|
||||
myprint(f'Running BORG mode: {last_line}')
|
||||
|
||||
return last_line
|
Loading…
Add table
Reference in a new issue