diff --git a/src/forwards.py b/src/forwards.py new file mode 100644 index 0000000..9d94e9c --- /dev/null +++ b/src/forwards.py @@ -0,0 +1,101 @@ +import jax +import jax.numpy as jnp +from functools import partial + + +@partial(jax.jit, static_argnums=(2,3)) +def dens2vel_linear(delta, f, Lbox, smooth_R): + """ + Converts the overdensity field, delta, into + the velocity field, v, in units of [???], assuming + the linear relation between the two. + + Args: + :delta (jnp.ndarray): (N,N,N) array of overdensity field + :f (float): The growth rate + :Lbox (float): The box length in Mpc/h + :smooth_R (float): The smoothing scale in Mpc/h + + Returns: + :v (jnp.ndarray): The velocity field defined on the grid + + TO CHECK: UNITS + QUESTION: Why two different filters? + """ + + N = delta.shape[0] + V = Lbox**3 + dV = (Lbox / N) ** 3 + + # Forward fft + delta_k = dV / V * jnp.fft.rfftn(delta) + delta_k = jnp.array([delta_k.real, delta_k.imag]) + + # Fourier mask for Nyquist planes + N_Z = N//2 + 1 + Fourier_mask = jnp.ones((N,N,N_Z)) + Fourier_mask = Fourier_mask.at[:,N_Z,0].set(0.) + Fourier_mask = Fourier_mask.at[:,N_Z,-1].set(0.) + + # Indices for symmetrising + update_index_real_0 = jnp.index_exp[0,N_Z:,:,0] + update_index_real_ny = jnp.index_exp[0,N_Z:,:,N_Z-1] + update_index_imag_0 = jnp.index_exp[1,N_Z:,:,0] + update_index_imag_ny = jnp.index_exp[1,N_Z:,:,N_Z-1] + flip_indices = -jnp.arange(N) + flip_indices = flip_indices.at[N_Z-1].set(-flip_indices[N_Z-1]) + flip_indices = jnp.array(flip_indices.tolist()) + + # Symmetrise + delta_k = Fourier_mask[jnp.newaxis] * delta_k + delta_k = delta_k.at[update_index_real_0].set( + jnp.take(jnp.flip(delta_k[0, 1:(N_Z-1), :, 0], axis=0), + flip_indices, axis=1) + ) + delta_k = delta_k.at[update_index_real_ny].set( + jnp.take(jnp.flip(delta_k[0,1:(N_Z-1),:,N_Z-1], axis=0), + flip_indices, axis=1) + ) + delta_k = delta_k.at[update_index_imag_0].set( + -jnp.take(jnp.flip(delta_k[1,1:(N_Z-1),:,0], axis=0), + flip_indices, axis=1) + ) + + delta_k = delta_k[0] + jnp.array(complex(0, 1)) * delta_k[1] + + # Get k grid + k = 2*jnp.pi*jnp.fft.fftfreq(N, d=Lbox/N) + k_norm = jnp.sqrt(k[:,None,None]**2 + k[None,:,None]**2 + kz_vec[None,None,:]**2) + k_norm = k_norm.at[(k_norm < 1e-10)].set(1e-15) + + # Filter + k_filter = jnp.exp(-0.5 * (k_norm[:,:,:N_Z] * smooth_R) ** 2) + smooth_filter = k_filter + vx = ( + smooth_filter * k_filter + * jnp.array(complex(0, 1)) * 100 * f + * delta_k_complex + * jnp.tile(k[:,None,None],(1,N,N_Z)) + / k_norm**2 + ) + vx = (jnp.fft.irfftn(vx)*V/dV) + + vy = ( + smooth_filter * k_filter + * jnp.array(complex(0, 1)) * 100 * f + * delta_k_complex + * jnp.tile(ky[None,:,None], (N,1,N_Z)) + / k_norm**2 + ) + vy = (jnp.fft.irfftn(vy)*V/dV) + + vz = ( + smooth_filter * k_filter + * jnp.array(complex(0, 1)) * 100 * f + * delta_k_complex + * jnp.tile(kz[None,None,:N_Z], (N,N,1)) + / k_norm**2 + ) + vz = (jnp.fft.irfftn(vz)*V/dV) + + return jnp.array([vx, vy, vz]) \ No newline at end of file diff --git a/src/projection.py b/src/projection.py new file mode 100644 index 0000000..13a361a --- /dev/null +++ b/src/projection.py @@ -0,0 +1,80 @@ +import jax.numpy as jnp +import jax.scipy.ndimage +import jax +from functools import partial + +@partial(jax.jit, static_argnames=['order']) +def jit_map_coordinates(image: jnp.ndarray, coords:jnp.ndarray, order: int) -> jnp.ndarray: + """ + Jitted version of jax.scipy.ndimage.map_coordinates + + Args: + - image (jnp.ndarray) - The input array + - coords (jnp.ndarray) - The coordinates at which image is evaluated. + - order (int): order of interpolation (0 <= order <= 5) + Returns: + - map_coordinates (jnp.ndarray) - The result of transforming the input. The shape of the output is derived from that of coordinates by dropping the first axis. + + """ + return jax.scipy.ndimage.map_coordinates(image, coords, order=order, mode='wrap') + + +@partial(jax.jit, static_argnames=['order','use_jitted']) +def interp_field(input_array: jnp.ndarray, coords: jnp.ndarray, L: float, origin: jnp.ndarray, order: int, use_jitted: bool=False) -> jnp.ndarray: + """ + Interpolate an array on a ND-cubic grid to new coordinates linearly + + Args: + - input_array (jnp.ndarray): array to be interpolated + - coords (jnp.ndarray shape=(ndim,npoint0,npoint1)): coordinates to evaluate at + - L (float): length of box + - origin (jnp.ndarray, shape=(ndim,)): position corresponding to index [0,0,...] + - order (int): order of interpolation (0 <= order <= 5) + + Returns: + - out_array (np.ndarray, SHAPE): field evaluated at coords of interest + + """ + + N = input_array.shape[-1] + + # Change coords to index + pos = (coords - jnp.expand_dims(origin, axis=(1,2))) * N / L + + # NOTE: jax's 'wrap' is the same as scipy's 'grid-wrap' + if use_jitted: + def fun_to_vmap(arr): + return jit_map_coordinates(arr, pos, order) + else: + def fun_to_vmap(arr): + return jax.scipy.ndimage.map_coordinates(arr, pos, order=order, mode='wrap') + + if len(input_array.shape) == coords.shape[0]: + out_array = jit_map_coordinates(input_array, pos, order) + elif len(input_array.shape) == coords.shape[0]+1: + vmap_fun = jax.vmap(fun_to_vmap, in_axes=[0]) + out_array = vmap_fun(input_array) + else: + raise NotImplementedError + + return out_array + + +@jax.jit +def project_radial(vec: jnp.ndarray, coords: jnp.ndarray, origin: jnp.ndarray) -> jnp.ndarray: + """ + Project vectors along the radial direction, given by coords + Args: + - vec (jnp.ndarray, shape=(ndim,npoint0,npoint1)): array of vectors to be projected + - coords (jnp.ndarray shape=(ndim,npoint0,npoint1)): coordinates to evaluate at + - origin (jnp.ndarray, shape=(ndim,)): coordinates of the observer + + Returns: + - vr (jnp.ndarray, shape=(npoint0,npoint1)): vec projected along radial direction + """ + x = coords - jnp.expand_dims(origin, axis=(1,2)) + r = jnp.sqrt(jnp.sum(x**2, axis=0)) + x = x / jnp.expand_dims(r, axis=0) + vr = jnp.sum(x * vec, axis=0) + + return vr \ No newline at end of file diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..52f9121 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,68 @@ +import aquila_borg as borg +import configparser +import os + +# Output stream management +cons = borg.console() +myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x)) + + +def get_cosmopar(ini_file): + """ + Extract cosmological parameters from an ini file + + Args: + :ini_file (str): Path to the ini file + + Returns: + :cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters + """ + + config = configparser.ConfigParser() + config.read(ini_file) + + cpar = borg.cosmo.CosmologicalParameters() + cpar.default() + cpar.fnl = float(config['cosmology']['fnl']) + cpar.omega_k = float(config['cosmology']['omega_k']) + cpar.omega_m = float(config['cosmology']['omega_m']) + cpar.omega_b = float(config['cosmology']['omega_b']) + cpar.omega_q = float(config['cosmology']['omega_q']) + cpar.h = float(config['cosmology']['h100']) + cpar.sigma8 = float(config['cosmology']['sigma8']) + cpar.n_s = float(config['cosmology']['n_s']) + cpar.w = float(config['cosmology']['w']) + cpar.wprime = float(config['cosmology']['wprime']) + + return cpar + + +def get_action(): + """ + Find mode which BORG is currently being run in + + Returns: + :last_line (str): The mode in which BORG is being run + """ + + # Find the last line of the file + with open('trace_file.txt', 'rb') as file: + file.seek(-2, os.SEEK_END) # Go to the second last byte in the file. + last_line = '' + while len(last_line) == 0: + while file.read(1) != b'\n': # Keep moving back until you find a newline. + file.seek(-2, os.SEEK_CUR) + last_line = file.readline().decode().strip('\n') + + # Check that the this is the line was want and find action + cmd = "hades_python" + assert cmd in last_line + idx = last_line.index(cmd) + last_line = last_line[idx+len(cmd):] + while last_line[0] == ' ': + last_line = last_line[1:] + idx = last_line.index(' ') + last_line = last_line[:idx].upper() + myprint(f'Running BORG mode: {last_line}') + + return last_line \ No newline at end of file