dens2vel, projection and cosmological parameter loading

This commit is contained in:
Deaglan Bartlett 2024-04-23 14:27:24 +02:00
parent 9828bd6efd
commit 874d26e38c
3 changed files with 249 additions and 0 deletions

101
src/forwards.py Normal file
View file

@ -0,0 +1,101 @@
import jax
import jax.numpy as jnp
from functools import partial
@partial(jax.jit, static_argnums=(2,3))
def dens2vel_linear(delta, f, Lbox, smooth_R):
"""
Converts the overdensity field, delta, into
the velocity field, v, in units of [???], assuming
the linear relation between the two.
Args:
:delta (jnp.ndarray): (N,N,N) array of overdensity field
:f (float): The growth rate
:Lbox (float): The box length in Mpc/h
:smooth_R (float): The smoothing scale in Mpc/h
Returns:
:v (jnp.ndarray): The velocity field defined on the grid
TO CHECK: UNITS
QUESTION: Why two different filters?
"""
N = delta.shape[0]
V = Lbox**3
dV = (Lbox / N) ** 3
# Forward fft
delta_k = dV / V * jnp.fft.rfftn(delta)
delta_k = jnp.array([delta_k.real, delta_k.imag])
# Fourier mask for Nyquist planes
N_Z = N//2 + 1
Fourier_mask = jnp.ones((N,N,N_Z))
Fourier_mask = Fourier_mask.at[:,N_Z,0].set(0.)
Fourier_mask = Fourier_mask.at[:,N_Z,-1].set(0.)
# Indices for symmetrising
update_index_real_0 = jnp.index_exp[0,N_Z:,:,0]
update_index_real_ny = jnp.index_exp[0,N_Z:,:,N_Z-1]
update_index_imag_0 = jnp.index_exp[1,N_Z:,:,0]
update_index_imag_ny = jnp.index_exp[1,N_Z:,:,N_Z-1]
flip_indices = -jnp.arange(N)
flip_indices = flip_indices.at[N_Z-1].set(-flip_indices[N_Z-1])
flip_indices = jnp.array(flip_indices.tolist())
# Symmetrise
delta_k = Fourier_mask[jnp.newaxis] * delta_k
delta_k = delta_k.at[update_index_real_0].set(
jnp.take(jnp.flip(delta_k[0, 1:(N_Z-1), :, 0], axis=0),
flip_indices, axis=1)
)
delta_k = delta_k.at[update_index_real_ny].set(
jnp.take(jnp.flip(delta_k[0,1:(N_Z-1),:,N_Z-1], axis=0),
flip_indices, axis=1)
)
delta_k = delta_k.at[update_index_imag_0].set(
-jnp.take(jnp.flip(delta_k[1,1:(N_Z-1),:,0], axis=0),
flip_indices, axis=1)
)
delta_k = delta_k[0] + jnp.array(complex(0, 1)) * delta_k[1]
# Get k grid
k = 2*jnp.pi*jnp.fft.fftfreq(N, d=Lbox/N)
k_norm = jnp.sqrt(k[:,None,None]**2 + k[None,:,None]**2 + kz_vec[None,None,:]**2)
k_norm = k_norm.at[(k_norm < 1e-10)].set(1e-15)
# Filter
k_filter = jnp.exp(-0.5 * (k_norm[:,:,:N_Z] * smooth_R) ** 2)
smooth_filter = k_filter
vx = (
smooth_filter * k_filter
* jnp.array(complex(0, 1)) * 100 * f
* delta_k_complex
* jnp.tile(k[:,None,None],(1,N,N_Z))
/ k_norm**2
)
vx = (jnp.fft.irfftn(vx)*V/dV)
vy = (
smooth_filter * k_filter
* jnp.array(complex(0, 1)) * 100 * f
* delta_k_complex
* jnp.tile(ky[None,:,None], (N,1,N_Z))
/ k_norm**2
)
vy = (jnp.fft.irfftn(vy)*V/dV)
vz = (
smooth_filter * k_filter
* jnp.array(complex(0, 1)) * 100 * f
* delta_k_complex
* jnp.tile(kz[None,None,:N_Z], (N,N,1))
/ k_norm**2
)
vz = (jnp.fft.irfftn(vz)*V/dV)
return jnp.array([vx, vy, vz])

80
src/projection.py Normal file
View file

@ -0,0 +1,80 @@
import jax.numpy as jnp
import jax.scipy.ndimage
import jax
from functools import partial
@partial(jax.jit, static_argnames=['order'])
def jit_map_coordinates(image: jnp.ndarray, coords:jnp.ndarray, order: int) -> jnp.ndarray:
"""
Jitted version of jax.scipy.ndimage.map_coordinates
Args:
- image (jnp.ndarray) - The input array
- coords (jnp.ndarray) - The coordinates at which image is evaluated.
- order (int): order of interpolation (0 <= order <= 5)
Returns:
- map_coordinates (jnp.ndarray) - The result of transforming the input. The shape of the output is derived from that of coordinates by dropping the first axis.
"""
return jax.scipy.ndimage.map_coordinates(image, coords, order=order, mode='wrap')
@partial(jax.jit, static_argnames=['order','use_jitted'])
def interp_field(input_array: jnp.ndarray, coords: jnp.ndarray, L: float, origin: jnp.ndarray, order: int, use_jitted: bool=False) -> jnp.ndarray:
"""
Interpolate an array on a ND-cubic grid to new coordinates linearly
Args:
- input_array (jnp.ndarray): array to be interpolated
- coords (jnp.ndarray shape=(ndim,npoint0,npoint1)): coordinates to evaluate at
- L (float): length of box
- origin (jnp.ndarray, shape=(ndim,)): position corresponding to index [0,0,...]
- order (int): order of interpolation (0 <= order <= 5)
Returns:
- out_array (np.ndarray, SHAPE): field evaluated at coords of interest
"""
N = input_array.shape[-1]
# Change coords to index
pos = (coords - jnp.expand_dims(origin, axis=(1,2))) * N / L
# NOTE: jax's 'wrap' is the same as scipy's 'grid-wrap'
if use_jitted:
def fun_to_vmap(arr):
return jit_map_coordinates(arr, pos, order)
else:
def fun_to_vmap(arr):
return jax.scipy.ndimage.map_coordinates(arr, pos, order=order, mode='wrap')
if len(input_array.shape) == coords.shape[0]:
out_array = jit_map_coordinates(input_array, pos, order)
elif len(input_array.shape) == coords.shape[0]+1:
vmap_fun = jax.vmap(fun_to_vmap, in_axes=[0])
out_array = vmap_fun(input_array)
else:
raise NotImplementedError
return out_array
@jax.jit
def project_radial(vec: jnp.ndarray, coords: jnp.ndarray, origin: jnp.ndarray) -> jnp.ndarray:
"""
Project vectors along the radial direction, given by coords
Args:
- vec (jnp.ndarray, shape=(ndim,npoint0,npoint1)): array of vectors to be projected
- coords (jnp.ndarray shape=(ndim,npoint0,npoint1)): coordinates to evaluate at
- origin (jnp.ndarray, shape=(ndim,)): coordinates of the observer
Returns:
- vr (jnp.ndarray, shape=(npoint0,npoint1)): vec projected along radial direction
"""
x = coords - jnp.expand_dims(origin, axis=(1,2))
r = jnp.sqrt(jnp.sum(x**2, axis=0))
x = x / jnp.expand_dims(r, axis=0)
vr = jnp.sum(x * vec, axis=0)
return vr

68
src/utils.py Normal file
View file

@ -0,0 +1,68 @@
import aquila_borg as borg
import configparser
import os
# Output stream management
cons = borg.console()
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))
def get_cosmopar(ini_file):
"""
Extract cosmological parameters from an ini file
Args:
:ini_file (str): Path to the ini file
Returns:
:cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters
"""
config = configparser.ConfigParser()
config.read(ini_file)
cpar = borg.cosmo.CosmologicalParameters()
cpar.default()
cpar.fnl = float(config['cosmology']['fnl'])
cpar.omega_k = float(config['cosmology']['omega_k'])
cpar.omega_m = float(config['cosmology']['omega_m'])
cpar.omega_b = float(config['cosmology']['omega_b'])
cpar.omega_q = float(config['cosmology']['omega_q'])
cpar.h = float(config['cosmology']['h100'])
cpar.sigma8 = float(config['cosmology']['sigma8'])
cpar.n_s = float(config['cosmology']['n_s'])
cpar.w = float(config['cosmology']['w'])
cpar.wprime = float(config['cosmology']['wprime'])
return cpar
def get_action():
"""
Find mode which BORG is currently being run in
Returns:
:last_line (str): The mode in which BORG is being run
"""
# Find the last line of the file
with open('trace_file.txt', 'rb') as file:
file.seek(-2, os.SEEK_END) # Go to the second last byte in the file.
last_line = ''
while len(last_line) == 0:
while file.read(1) != b'\n': # Keep moving back until you find a newline.
file.seek(-2, os.SEEK_CUR)
last_line = file.readline().decode().strip('\n')
# Check that the this is the line was want and find action
cmd = "hades_python"
assert cmd in last_line
idx = last_line.index(cmd)
last_line = last_line[idx+len(cmd):]
while last_line[0] == ' ':
last_line = last_line[1:]
idx = last_line.index(' ')
last_line = last_line[:idx].upper()
myprint(f'Running BORG mode: {last_line}')
return last_line