Basic structure for forwards and likelihood
This commit is contained in:
parent
80c0224e71
commit
21e8fd82ec
8 changed files with 1625 additions and 0 deletions
82
conf/basic_ini.ini
Normal file
82
conf/basic_ini.ini
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
[system]
|
||||||
|
console_output = borg_log
|
||||||
|
VERBOSE_LEVEL = 2
|
||||||
|
N0 = 128
|
||||||
|
N1 = 128
|
||||||
|
N2 = 128
|
||||||
|
L0 = 500.0
|
||||||
|
L1 = 500.0
|
||||||
|
L2 = 500.0
|
||||||
|
corner0 = -250.0
|
||||||
|
corner1 = -250.0
|
||||||
|
corner2 = -250.0
|
||||||
|
NUM_MODES = 100
|
||||||
|
test_mode = true
|
||||||
|
seed_cpower = true
|
||||||
|
|
||||||
|
[block_loop]
|
||||||
|
hades_sampler_blocked = false
|
||||||
|
bias_sampler_blocked= true
|
||||||
|
gauss_sigma_sampler_blocked = false
|
||||||
|
ares_heat = 1.0
|
||||||
|
|
||||||
|
[mcmc]
|
||||||
|
number_to_generate = 10
|
||||||
|
warmup_model = 3
|
||||||
|
warmup_cosmo = 7
|
||||||
|
random_ic = false
|
||||||
|
init_random_scaling = 0.1
|
||||||
|
bignum = 1e300
|
||||||
|
|
||||||
|
[hades]
|
||||||
|
algorithm = HMC
|
||||||
|
max_epsilon = 0.01
|
||||||
|
max_timesteps = 50
|
||||||
|
mixing = 1
|
||||||
|
|
||||||
|
[model]
|
||||||
|
gravity = lpt
|
||||||
|
logfR0 = -5.0
|
||||||
|
af = 1.0
|
||||||
|
ai = 0.05
|
||||||
|
|
||||||
|
[prior]
|
||||||
|
omega_m = [0.1, 0.8]
|
||||||
|
sigma8 = [0.1, 1.5]
|
||||||
|
muA = [0.5, 1.5]
|
||||||
|
alpha = [0.0, 10.0]
|
||||||
|
sig_v = [50.0, 200.0]
|
||||||
|
bulk_flow = [-200.0, 200.0]
|
||||||
|
|
||||||
|
[cosmology]
|
||||||
|
omega_r = 0
|
||||||
|
fnl = 0
|
||||||
|
omega_k = 0
|
||||||
|
omega_m = 0.315
|
||||||
|
omega_b = 0.049
|
||||||
|
omega_q = 0.685
|
||||||
|
h100 = 0.68
|
||||||
|
sigma8 = 0.81
|
||||||
|
n_s = 0.97
|
||||||
|
w = -1
|
||||||
|
wprime = 0
|
||||||
|
beta = 1.5
|
||||||
|
z0 = 0
|
||||||
|
|
||||||
|
[emulator]
|
||||||
|
use_emulator = True
|
||||||
|
model_weights_path = /home/bartlett/mgborg_emulator/weights/emulator_weights.file_type
|
||||||
|
architecture = StyledVNet
|
||||||
|
use_float64 = False
|
||||||
|
use_pad_and_NN = True
|
||||||
|
requires_grad = False
|
||||||
|
|
||||||
|
[run]
|
||||||
|
run_type = mock
|
||||||
|
NCAT = 0
|
||||||
|
|
||||||
|
[mock]
|
||||||
|
seed = 123
|
||||||
|
|
||||||
|
[python]
|
||||||
|
likelihood_path = /home/bartlett/mgborg_emulator/mgborg_emulator/likelihood.py
|
470
mgborg_emulator/forwards.py
Normal file
470
mgborg_emulator/forwards.py
Normal file
|
@ -0,0 +1,470 @@
|
||||||
|
import aquila_borg as borg
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Emulator imports
|
||||||
|
import torch
|
||||||
|
# from .model_architecture.cosmology import *
|
||||||
|
from networks import StyledVNet, StyledVNet_distill
|
||||||
|
|
||||||
|
# JAX set-up
|
||||||
|
import jax
|
||||||
|
from functools import partial
|
||||||
|
import jax.numpy as jnp
|
||||||
|
from jax import vjp
|
||||||
|
from jax.config import config as jax_config
|
||||||
|
jax_config.update("jax_enable_x64", True)
|
||||||
|
|
||||||
|
from utils import myprint
|
||||||
|
|
||||||
|
class Emulator(borg.forward.BaseForwardModel):
|
||||||
|
|
||||||
|
def __init__(self, box, prev_module, NN, Om, upan, use_float64, q, requires_grad, cuda_avail, debug=False):
|
||||||
|
super().__init__(box, box)
|
||||||
|
|
||||||
|
# Set-up and cosmo
|
||||||
|
self.box = box
|
||||||
|
self.q = q
|
||||||
|
self.Om = Om
|
||||||
|
self.debug = debug
|
||||||
|
|
||||||
|
# LPT and EMU
|
||||||
|
self.prev_module = prev_module
|
||||||
|
self.NN = NN
|
||||||
|
|
||||||
|
# Since we won't return an adjoint for the density, but for positions, we need to do:
|
||||||
|
self.prev_module.accumulateAdjoint(True)
|
||||||
|
|
||||||
|
# Settings
|
||||||
|
self.requires_grad = requires_grad
|
||||||
|
self.use_pad_and_NN = upan
|
||||||
|
self.use_float64 = use_float64
|
||||||
|
self.cuda_avail = cuda_avail
|
||||||
|
if self.use_float64:
|
||||||
|
self.dtype = np.float64
|
||||||
|
else:
|
||||||
|
self.dtype = np.float32
|
||||||
|
|
||||||
|
def requires_grad_true(self):
|
||||||
|
self.requires_grad = True
|
||||||
|
|
||||||
|
def requires_grad_false(self):
|
||||||
|
self.requires_grad = False
|
||||||
|
|
||||||
|
# IO "preferences"
|
||||||
|
def getPreferredInput(self):
|
||||||
|
return borg.forward.PREFERRED_REAL
|
||||||
|
|
||||||
|
def getPreferredOutput(self):
|
||||||
|
return borg.forward.PREFERRED_REAL
|
||||||
|
|
||||||
|
# Forward part
|
||||||
|
def forwardModel_v2_impl(self, input_array):
|
||||||
|
global requires_grad
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
myprint(f'Requires grad = {self.requires_grad}')
|
||||||
|
|
||||||
|
init_time = time.time()
|
||||||
|
|
||||||
|
# Step 0 - Extract particle positions
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
pos = np.zeros((self.prev_module.getNumberOfParticles(), 3)) #output shape: (N^3, 3)
|
||||||
|
self.prev_module.getParticlePositions(pos)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 0 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 1 - find displacements
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
disp = pos - self.q
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 1 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 2 - correct for particles that moved over the periodic boundary
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
disp_temp = correct_displacement_over_periodic_boundaries(disp,L=self.box.L[0],max_disp_1d=self.box.L[0]//2)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 2 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 3 - reshaping initial pos and displacement
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
# not sure why order='C' is working here... not sure if it matters... could change it below
|
||||||
|
q_reshaped = np.reshape(self.q.T, (3,self.box.N[0],self.box.N[0],self.box.N[0]), order='C') #output shape: (3, N, N, N)
|
||||||
|
dis_in = np.reshape(disp_temp.T, (3,self.box.N[0],self.box.N[0],self.box.N[0]), order='C') #output shape: (3, N, N, N)
|
||||||
|
if self.debug:
|
||||||
|
self.dis_in = dis_in
|
||||||
|
myprint("Step 3 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 4 - normalize
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
dis(dis_in)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 4 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
if self.use_pad_and_NN:
|
||||||
|
if self.debug:
|
||||||
|
myprint('Using padding and NN.')
|
||||||
|
|
||||||
|
if not self.use_float64:
|
||||||
|
dis_in = dis_in.astype(np.float32)
|
||||||
|
|
||||||
|
# Step 5 - padding to (3,N+48*2,N+48*2,N+48*2)
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# for jax adjoint
|
||||||
|
if self.debug:
|
||||||
|
myprint(f'in_pad shape = {dis_in.shape}')
|
||||||
|
dis_in_padded, self.ag_pad = vjp(self.padding, dis_in) #output shape: (3, N+96, N+96, N+96)
|
||||||
|
if self.debug:
|
||||||
|
myprint(f'out_pad shape = {np.shape(np.asarray(dis_in_padded))}')
|
||||||
|
myprint("Step 5 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 6 - turn into a pytorch tensor (unsquueze because batch = 1)
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
if self.use_float64:
|
||||||
|
self.x = torch.unsqueeze(torch.tensor(np.asarray(dis_in_padded),dtype=torch.float64, requires_grad=self.requires_grad),dim=0) #output shape: (1, 3, N+96, N+96, N+96)
|
||||||
|
else:
|
||||||
|
self.x = torch.unsqueeze(torch.tensor(np.asarray(dis_in_padded),dtype=torch.float32, requires_grad=self.requires_grad),dim=0) #output shape: (1, 3, N+96, N+96, N+96)
|
||||||
|
if self.cuda_avail:
|
||||||
|
self.x = self.x.cuda()
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 6 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 7 - Pipe through emulator
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
if self.requires_grad:
|
||||||
|
self.y = self.NN(self.x,self.Om) #output shape: (1, 3, N, N, N)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.y = self.NN(self.x,self.Om) #output shape: (1, 3, N, N, N)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 7 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 8 - N-body sim displacement
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
dis_out = torch.squeeze(self.y).detach().cpu().numpy()
|
||||||
|
|
||||||
|
else:
|
||||||
|
myprint('Skipping padding and NN.')
|
||||||
|
dis_out = dis_in
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 8 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 9 - undo the normalization
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
dis(dis_out,undo=True)
|
||||||
|
if self.debug:
|
||||||
|
self.dis_out = dis_out
|
||||||
|
myprint("Step 9 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 10 - convert displacement into positions
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
pos = dis_out + q_reshaped
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 10 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 11 - make sure everything within the box
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
pos[pos>=self.box.L[0]] -= self.box.L[0]
|
||||||
|
pos[pos<0] += self.box.L[0]
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 11 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 12 - reshape positions
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
self.pos_out = pos.reshape(3,self.box.N[0]**3,order='C').T #output shape: (N^3, 3)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 12 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# Step 13 - CIC
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
#self.dens_out = cic_analytical(self.pos_out[:, 0], self.pos_out[:, 1], self.pos_out[:, 2], *self.box.N + self.box.L)
|
||||||
|
self.dens_out, _ = vjp(self.jax_cic, self.pos_out[:, 0], self.pos_out[:, 1], self.pos_out[:, 2])
|
||||||
|
|
||||||
|
if self.debug:
|
||||||
|
myprint("Step 13 of forward pass took %s seconds" % (time.time() - start_time))
|
||||||
|
myprint('Finished forward pass of emu in %s seconds' % (time.time() - init_time))
|
||||||
|
|
||||||
|
def getParticlePositions(self,output_array):
|
||||||
|
output_array[:] = self.pos_out
|
||||||
|
|
||||||
|
def getDensityFinal_impl(self, output_array):
|
||||||
|
output_array[:] = self.dens_out
|
||||||
|
|
||||||
|
# Adjoint part
|
||||||
|
def adjointModel_v2_impl(self, input_ag):
|
||||||
|
|
||||||
|
init_time = time.time()
|
||||||
|
# reverse step 13 (CIC)
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
ag = np.asarray(self.cic_analytical_grad(input_ag,self.pos_out[:,0],self.pos_out[:,1],self.pos_out[:,2],128,128,128,250,250,250))
|
||||||
|
ag = np.copy(ag)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 13 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# reverse step 11
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
ag = np.reshape(ag.T, (3,*self.box.N), order='C') #changed recently
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 11 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# reverse step 9
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
dis(ag,undo=False)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 9 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
if self.use_pad_and_NN:
|
||||||
|
# reverse step 8
|
||||||
|
if self.debug:
|
||||||
|
myprint('adjoint... Using padding and NN.')
|
||||||
|
start_time = time.time()
|
||||||
|
if self.use_float64:
|
||||||
|
ag = torch.unsqueeze(torch.tensor(ag,dtype=torch.float64),dim=0)
|
||||||
|
else:
|
||||||
|
ag = torch.unsqueeze(torch.tensor(ag,dtype=torch.float32),dim=0)
|
||||||
|
if self.cuda_avail:
|
||||||
|
ag = ag.cuda()
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 8 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# reverse step 7
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
ag = torch.autograd.grad(self.y, self.x, grad_outputs=ag, retain_graph=False)[0]
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 7 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# reverse step 6
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
ag = torch.squeeze(ag).detach().cpu().numpy()
|
||||||
|
#ag = torch.squeeze(ag).detach().numpy()
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 6 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# reverse step 5
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
ag = np.asarray(self.ag_pad(ag))[0] #not sure why adjoint outputs shape (1,3,128,128,128)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 5 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
else:
|
||||||
|
myprint('adjoint... Skipping padding and NN.')
|
||||||
|
|
||||||
|
# reverse step 4
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
dis(ag,undo=True)
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 4 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# reverse step 3
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
self.ag_pos = ag.reshape(3,self.box.N[0]**3,order='C').T
|
||||||
|
if self.debug:
|
||||||
|
myprint("Reverse step 3 took %s seconds" % (time.time() - start_time))
|
||||||
|
|
||||||
|
# memory
|
||||||
|
if self.debug:
|
||||||
|
start_time = time.time()
|
||||||
|
self.ag_pos = np.copy(self.ag_pos,order='C')
|
||||||
|
if self.debug:
|
||||||
|
myprint("Fixing memory issue took %s seconds" % (time.time() - start_time))
|
||||||
|
myprint('Finished backward pass of emu in %s seconds' % (time.time() - init_time))
|
||||||
|
|
||||||
|
def getAdjointModel_impl(self, output_ag):
|
||||||
|
# Set adjoint gradient wrt density field to zero
|
||||||
|
output_ag[:] = 0
|
||||||
|
# and instead specify the adjoint gradient wrt particle positions
|
||||||
|
self.prev_module.adjointModelParticles(ag_pos=np.array(self.ag_pos, dtype=self.dtype), ag_vel=np.zeros_like(self.ag_pos, dtype=self.dtype))
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
|
def padding(self,x):
|
||||||
|
return jnp.pad(x,((0,0),(48,48),(48,48),(48,48)),'wrap')
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(0,))
|
||||||
|
def jax_cic(self, x, y, z):
|
||||||
|
# TODO: fix hard-coded N and L
|
||||||
|
Nx = Ny = Nz = 128
|
||||||
|
Lx = Ly = Lz = 250.
|
||||||
|
|
||||||
|
Ntot = Nx * Ny * Nz
|
||||||
|
x = x * Nx / Lx
|
||||||
|
y = y * Ny / Ly
|
||||||
|
z = z * Nz / Lz
|
||||||
|
|
||||||
|
qx, ix = get_cell_coord(x)
|
||||||
|
qy, iy = get_cell_coord(y)
|
||||||
|
qz, iz = get_cell_coord(z)
|
||||||
|
|
||||||
|
ix = ix.astype(int)
|
||||||
|
iy = iy.astype(int)
|
||||||
|
iz = iz.astype(int)
|
||||||
|
rx = 1.0 - qx
|
||||||
|
ry = 1.0 - qy
|
||||||
|
rz = 1.0 - qz
|
||||||
|
jx = (ix + 1) % Nx
|
||||||
|
jy = (iy + 1) % Ny
|
||||||
|
jz = (iz + 1) % Nz
|
||||||
|
|
||||||
|
rho = jnp.zeros((Ntot,))
|
||||||
|
|
||||||
|
for a in [False, True]:
|
||||||
|
for b in [False, True]:
|
||||||
|
for c in [False, True]:
|
||||||
|
ax = jx if a else ix
|
||||||
|
ay = jy if b else iy
|
||||||
|
az = jz if c else iz
|
||||||
|
ux = qx if a else rx
|
||||||
|
uy = qy if b else ry
|
||||||
|
uz = qz if c else rz
|
||||||
|
|
||||||
|
idx = az + Nz * ay + Nz * Ny * ax
|
||||||
|
rho += jnp.bincount(idx, weights=ux * uy * uz, length=Ntot)
|
||||||
|
|
||||||
|
return rho.reshape((Nx, Ny, Nz)) / (x.shape[0] / Ntot) - 1.0
|
||||||
|
|
||||||
|
@partial(jax.jit, static_argnums=(0,5,6,7,8,9,10))
|
||||||
|
def cic_analytical_grad(self,density, x, y, z, Nx, Ny, Nz, Lx, Ly, Lz):
|
||||||
|
Ntot = Nx * Ny * Nz
|
||||||
|
x = x * Nx / Lx
|
||||||
|
y = y * Ny / Ly
|
||||||
|
z = z * Nz / Lz
|
||||||
|
|
||||||
|
qx, ix = get_cell_coord(x)
|
||||||
|
qy, iy = get_cell_coord(y)
|
||||||
|
qz, iz = get_cell_coord(z)
|
||||||
|
|
||||||
|
ix = ix.astype(int)
|
||||||
|
iy = iy.astype(int)
|
||||||
|
iz = iz.astype(int)
|
||||||
|
|
||||||
|
rx = 1.0 - qx
|
||||||
|
ry = 1.0 - qy
|
||||||
|
rz = 1.0 - qz
|
||||||
|
jx = (ix + 1) % Nx
|
||||||
|
jy = (iy + 1) % Ny
|
||||||
|
jz = (iz + 1) % Nz
|
||||||
|
|
||||||
|
adj_gradient = jnp.zeros((Nx**3,3))
|
||||||
|
|
||||||
|
for coord in np.arange(3):
|
||||||
|
if coord==0:
|
||||||
|
rx = 1
|
||||||
|
qx = -1
|
||||||
|
ry = y - iy
|
||||||
|
qy = 1 - ry
|
||||||
|
rz = z - iz
|
||||||
|
qz = 1 - rz
|
||||||
|
|
||||||
|
adj_gradient = adj_gradient.at[:,0].set(density[ix,iy,iz] * qx * qy * qz + \
|
||||||
|
density[ix,iy,jz] * qx * qy * rz + \
|
||||||
|
density[ix,jy,iz] * qx * ry * qz + \
|
||||||
|
density[ix,jy,jz] * qx * ry * rz + \
|
||||||
|
density[jx,iy,iz] * rx * qy * qz + \
|
||||||
|
density[jx,iy,jz] * rx * qy * rz + \
|
||||||
|
density[jx,jy,iz] * rx * ry * qz + \
|
||||||
|
density[jx,jy,jz] * rx * ry * rz)
|
||||||
|
|
||||||
|
elif coord==1:
|
||||||
|
rx = x - ix;
|
||||||
|
qx = 1 - rx;
|
||||||
|
ry = 1;
|
||||||
|
qy = -1;
|
||||||
|
rz = z - iz;
|
||||||
|
qz = 1 - rz;
|
||||||
|
|
||||||
|
adj_gradient = adj_gradient.at[:,1].set(density[ix,iy,iz] * qx * qy * qz + \
|
||||||
|
density[ix,iy,jz] * qx * qy * rz + \
|
||||||
|
density[ix,jy,iz] * qx * ry * qz + \
|
||||||
|
density[ix,jy,jz] * qx * ry * rz + \
|
||||||
|
density[jx,iy,iz] * rx * qy * qz + \
|
||||||
|
density[jx,iy,jz] * rx * qy * rz + \
|
||||||
|
density[jx,jy,iz] * rx * ry * qz + \
|
||||||
|
density[jx,jy,jz] * rx * ry * rz)
|
||||||
|
else:
|
||||||
|
rx = x - ix;
|
||||||
|
qx = 1 - rx;
|
||||||
|
ry = y - iy;
|
||||||
|
qy = 1 - ry;
|
||||||
|
rz = 1;
|
||||||
|
qz = -1;
|
||||||
|
|
||||||
|
adj_gradient = adj_gradient.at[:,2].set(density[ix,iy,iz] * qx * qy * qz +
|
||||||
|
density[ix,iy,jz] * qx * qy * rz +
|
||||||
|
density[ix,jy,iz] * qx * ry * qz +
|
||||||
|
density[ix,jy,jz] * qx * ry * rz +
|
||||||
|
density[jx,iy,iz] * rx * qy * qz +
|
||||||
|
density[jx,iy,jz] * rx * qy * rz +
|
||||||
|
density[jx,jy,iz] * rx * ry * qz +
|
||||||
|
density[jx,jy,jz] * rx * ry * rz)
|
||||||
|
|
||||||
|
|
||||||
|
adj_gradient*= Nx / Lx
|
||||||
|
|
||||||
|
return adj_gradient
|
||||||
|
|
||||||
|
|
||||||
|
class NullForward(borg.forward.BaseForwardModel):
|
||||||
|
"""
|
||||||
|
BORG forward model which does nothing but stores
|
||||||
|
the values of parameters to be used by the likelihood
|
||||||
|
"""
|
||||||
|
def __init__(self, box: borg.forward.BoxModel) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the NullForward class
|
||||||
|
Args:
|
||||||
|
box (borg.forward.BoxModel): The input box model.
|
||||||
|
"""
|
||||||
|
super().__init__(box, box)
|
||||||
|
self.setName("nullforward")
|
||||||
|
|
||||||
|
self.params = {}
|
||||||
|
self.setCosmoParams(borg.cosmo.CosmologicalParameters())
|
||||||
|
cosmo = self.getCosmoParams()
|
||||||
|
cosmo.n_s = 0.96241
|
||||||
|
self.setCosmoParams(cosmo)
|
||||||
|
|
||||||
|
|
||||||
|
def setModelParams(self, params: dict) -> None:
|
||||||
|
"""
|
||||||
|
Change the values of the model parameters to those given by params
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params (dict): Dictionary of updated model parameters.
|
||||||
|
"""
|
||||||
|
for k, v in params.items():
|
||||||
|
self.params[k] = v
|
||||||
|
print(" ")
|
||||||
|
myprint(f'Updated model parameters: {self.params}')
|
||||||
|
|
||||||
|
def getModelParam(self, model, keyname: str):
|
||||||
|
"""
|
||||||
|
This queries the current state of the parameters keyname in model model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model
|
||||||
|
keyname (str): The name of the parameter of interest
|
||||||
|
"""
|
||||||
|
return self.params[keyname]
|
516
mgborg_emulator/likelihoods.py
Normal file
516
mgborg_emulator/likelihoods.py
Normal file
|
@ -0,0 +1,516 @@
|
||||||
|
import numpy as np
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import configparser
|
||||||
|
import warnings
|
||||||
|
import aquila_borg as borg
|
||||||
|
import symbolic_pofk.linear
|
||||||
|
import jax
|
||||||
|
from functools import partial
|
||||||
|
import ast
|
||||||
|
|
||||||
|
import utils as utils
|
||||||
|
from utils import myprint
|
||||||
|
import forwards
|
||||||
|
import networks
|
||||||
|
|
||||||
|
class GaussianLikelihood(borg.likelihood.BaseLikelihood):
|
||||||
|
"""
|
||||||
|
HADES Gaussian likelihood
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fwd: borg.forward.BaseForwardModel, param_model: forwards.NullForward, ini_file: str) -> None:
|
||||||
|
"""
|
||||||
|
Initialises the GaussianLikelihood class
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- fwd (borg.forward.BaseForwardModel): The forward model to be used in the likelihood.
|
||||||
|
- param_model (forwards.NullForward): An empty forward model for storing model parameters.
|
||||||
|
- ini_file (str): The name of the ini file containing the model and borg parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.ini_file = ini_file
|
||||||
|
|
||||||
|
myprint("Reading from configuration file: " + ini_file)
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(ini_file)
|
||||||
|
|
||||||
|
# Grid parameters
|
||||||
|
self.N = [int(config['system'][f'N{i}']) for i in range(3)]
|
||||||
|
self.L = [float(config['system'][f'L{i}']) for i in range(3)]
|
||||||
|
|
||||||
|
# What type of run we're doing
|
||||||
|
self.run_type = config['run']['run_type']
|
||||||
|
|
||||||
|
# Seed if creating mocks
|
||||||
|
self.mock_seed = int(config['mock']['seed'])
|
||||||
|
|
||||||
|
# For log-likelihood values
|
||||||
|
self.bignum = float(config['mcmc']['bignum'])
|
||||||
|
|
||||||
|
myprint(f" Init {self.N}, {self.L}")
|
||||||
|
super().__init__(fwd, self.N, self.L)
|
||||||
|
|
||||||
|
# Define the forward models
|
||||||
|
self.fwd = fwd
|
||||||
|
self.fwd_param = param_model
|
||||||
|
|
||||||
|
# Initialise model parameters
|
||||||
|
model_params = {
|
||||||
|
'logfR0':float(config['model']['logfr0'])
|
||||||
|
'gauss_sigma':float(config['model']['gauss_sigma'])
|
||||||
|
}
|
||||||
|
self.fwd_param.setModelParams(model_params)
|
||||||
|
|
||||||
|
# Initialise cosmological parameters
|
||||||
|
cpar = utils.get_cosmopar(self.ini_file)
|
||||||
|
self.fwd.setCosmoParams(cpar)
|
||||||
|
self.fwd_param.setCosmoParams(cpar)
|
||||||
|
self.updateCosmology(cpar)
|
||||||
|
myprint(f"Original cosmological parameters: {self.fwd.getCosmoParams()}")
|
||||||
|
|
||||||
|
# Initialise derivative
|
||||||
|
self.grad_like = jax.grad(self.dens2like)
|
||||||
|
|
||||||
|
|
||||||
|
def initializeLikelihood(self, state: borg.likelihood.MarkovState) -> None:
|
||||||
|
"""
|
||||||
|
Initialise the likelihood internal variables and MarkovState variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
||||||
|
"""
|
||||||
|
|
||||||
|
myprint("Init likelihood")
|
||||||
|
state.newArray3d("mock", *self.fwd.getOutputBoxModel().N, False)
|
||||||
|
state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)
|
||||||
|
|
||||||
|
|
||||||
|
def updateMetaParameters(self, state: borg.likelihood.MarkovState) -> None:
|
||||||
|
"""
|
||||||
|
Update the meta parameters of the sampler (not sampled) from the MarkovState.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- state (borg.likelihood.MarkovState): The state object to be used in the likelihood.
|
||||||
|
|
||||||
|
"""
|
||||||
|
cpar = state['cosmology']
|
||||||
|
cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
|
||||||
|
self.fwd.setCosmoParams(cpar)
|
||||||
|
self.fwd_param.setCosmoParams(cpar)
|
||||||
|
|
||||||
|
|
||||||
|
def updateCosmology(self, cosmo: borg.cosmo.CosmologicalParameters) -> None:
|
||||||
|
"""
|
||||||
|
Updates the forward model's cosmological parameters with the given values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- cosmo (borg.cosmo.CosmologicalParameters): The cosmological parameters.
|
||||||
|
|
||||||
|
"""
|
||||||
|
cpar = cosmo
|
||||||
|
|
||||||
|
# Convert sigma8 to As
|
||||||
|
cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
|
||||||
|
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
|
||||||
|
myprint(f"Updating cosmology Om = {cosmo.omega_m}, sig8 = {cosmo.sigma8}, As = {cosmo.A_s}")
|
||||||
|
|
||||||
|
cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
|
||||||
|
self.fwd.setCosmoParams(cpar)
|
||||||
|
self.fwd_param.setCosmoParams(cpar)
|
||||||
|
|
||||||
|
|
||||||
|
def generateMockData(self, s_hat: np.ndarray, state: borg.likelihood.MarkovState,) -> None:
|
||||||
|
"""
|
||||||
|
Generates mock data by simulating the forward model with the given white noise,
|
||||||
|
drawing distance tracers from the density field, computing their distance
|
||||||
|
moduli and radial velocities, and adding Gaussian noise to the appropriate
|
||||||
|
variables. Also calculates the initial negative log-likelihood of the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- s_hat (np.ndarray): The input (initial) density field.
|
||||||
|
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
||||||
|
- make_plot (bool, default=True): Whether to make diagnostic plots for the mock data generation
|
||||||
|
"""
|
||||||
|
if self.run_type == 'data':
|
||||||
|
raise NotImplementedError
|
||||||
|
elif self.run_type == 'mock':
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def dens2like(self, output_density: np.ndarray):
|
||||||
|
"""
|
||||||
|
Given stored data, computes the negative log-likelihood of the data
|
||||||
|
for this final density field.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_density (np.ndarray): The z=0 density field.
|
||||||
|
Return:
|
||||||
|
lkl (float): The negative log-likelihood of the data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sigma = self.fwd_param.getModelParam('nullforward', 'gauss_sigma')
|
||||||
|
lkl = (self.data - output_density)**2 / (2. * sigma ** 2) + 0.5 * jnp.log(2 * jnp.pi * sigma**2)
|
||||||
|
lkl = jnp.sum(lkl)
|
||||||
|
|
||||||
|
if not jnp.isfinite(lkl):
|
||||||
|
lkl = self.bignum
|
||||||
|
|
||||||
|
return lkl
|
||||||
|
|
||||||
|
def logLikelihoodComplex(self, s_hat: np.ndarray, gradientIsNext: bool):
|
||||||
|
"""
|
||||||
|
Calculates the negative log-likelihood of the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- s_hat (np.ndarray): The input white noise.
|
||||||
|
- gradientIsNext (bool): If True, prepares the forward model for gradient calculations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The negative log-likelihood value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
N = self.fwd.getBoxModel().N[0]
|
||||||
|
L = self.fwd.getOutputBoxModel().L[0]
|
||||||
|
|
||||||
|
# Run BORG density field
|
||||||
|
output_density = np.zeros((N,N,N))
|
||||||
|
self.fwd.forwardModel_v2(s_hat)
|
||||||
|
self.fwd.getDensityFinal(output_density)
|
||||||
|
|
||||||
|
# Get velocity field
|
||||||
|
output_velocity = self.fwd_vel.getVelocityField()
|
||||||
|
|
||||||
|
self.delta = output_density
|
||||||
|
self.vel = output_velocity
|
||||||
|
|
||||||
|
L = self.dens2like(output_density, output_velocity)
|
||||||
|
myprint(f"var(s_hat): {np.var(s_hat)}, Call to logLike: {L}")
|
||||||
|
|
||||||
|
return L
|
||||||
|
|
||||||
|
|
||||||
|
def gradientLikelihoodComplex(self, s_hat: np.ndarray):
|
||||||
|
"""
|
||||||
|
Calculates the adjoint negative log-likelihood of the data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- s_hat (np.ndarray): The input density field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The adjoint negative log-likelihood gradient.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
N = self.fwd.getBoxModel().N[0]
|
||||||
|
L = self.fwd.getOutputBoxModel().L[0]
|
||||||
|
|
||||||
|
# Run BORG density field
|
||||||
|
output_density = np.zeros((N,N,N))
|
||||||
|
self.fwd.forwardModel_v2(s_hat)
|
||||||
|
self.fwd.getDensityFinal(output_density)
|
||||||
|
|
||||||
|
# getlike(dens, vel)
|
||||||
|
mygradient = self.grad_like(output_density)
|
||||||
|
mygradient = np.array(mygradient, dtype=np.float64)
|
||||||
|
|
||||||
|
self.fwd.adjointModel_v2(mygradient)
|
||||||
|
mygrad_hat = np.zeros(s_hat.shape, dtype=np.complex128)
|
||||||
|
self.fwd.getAdjointModel(mygrad_hat)
|
||||||
|
elf.fwd.clearAdjointGradient()
|
||||||
|
|
||||||
|
return mygrad_hat
|
||||||
|
|
||||||
|
|
||||||
|
def commitAuxiliaryFields(self, state: borg.likelihood.MarkovState) -> None:
|
||||||
|
"""
|
||||||
|
Commits the final density field to the Markov state.
|
||||||
|
Args:
|
||||||
|
- state (borg.state.State): The state object containing the final density field.
|
||||||
|
"""
|
||||||
|
self.updateCosmology(self.fwd.getCosmoParams())
|
||||||
|
self.dens2like(self.delta)
|
||||||
|
state["BORG_final_density"][:] = self.delta
|
||||||
|
|
||||||
|
|
||||||
|
@borg.registerGravityBuilder
|
||||||
|
def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.BoxModel, ini_file=None) -> borg.forward.BaseForwardModel:
|
||||||
|
"""
|
||||||
|
Builds the gravity model and returns the forward model chain.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
||||||
|
- box (borg.forward.BoxModel): The input box model.
|
||||||
|
- ini_file (str, default=None): The location of the ini file. If None, use borg.getIniConfigurationFilename()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
borg.forward.BaseForwardModel: The forward model.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
global chain, fwd_param
|
||||||
|
myprint("Building gravity model")
|
||||||
|
|
||||||
|
if ini_file is None:
|
||||||
|
myprint("Reading from configuration file: " + borg.getIniConfigurationFilename())
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(borg.getIniConfigurationFilename())
|
||||||
|
else:
|
||||||
|
myprint("Reading from configuration file: " + ini_file)
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(ini_file)
|
||||||
|
ai = float(config['model']['ai'])
|
||||||
|
af = float(config['model']['af'])
|
||||||
|
|
||||||
|
# Cosmological parameters
|
||||||
|
if ini_file is None:
|
||||||
|
cpar = utils.get_cosmopar(borg.getIniConfigurationFilename())
|
||||||
|
else:
|
||||||
|
cpar = utils.get_cosmopar(ini_file)
|
||||||
|
chain.setCosmoParams(cpar)
|
||||||
|
|
||||||
|
# Setup forward model
|
||||||
|
chain = borg.forward.ChainForwardModel(box)
|
||||||
|
chain.addModel(borg.forward.models.HermiticEnforcer(box))
|
||||||
|
|
||||||
|
# CLASS transfer function
|
||||||
|
chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
|
||||||
|
transfer_class = borg.forward.model_lib.M_TRANSFER_CLASS(box, opts=dict(a_transfer=1.0))
|
||||||
|
transfer_class.setModelParams({"extra_class_arguments":{'YHe':'0.24'}})
|
||||||
|
chain @= transfer_class
|
||||||
|
|
||||||
|
if config['model']['gravity'] == 'lpt':
|
||||||
|
lpt = borg.forward.model_lib.M_LPT_CIC(
|
||||||
|
box,
|
||||||
|
opts=dict(a_initial=af,
|
||||||
|
a_final=af,
|
||||||
|
do_rsd=False,
|
||||||
|
supersampling=1,
|
||||||
|
lightcone=False,
|
||||||
|
part_factor=1.01,))
|
||||||
|
elif config['model']['gravity'] == '2lpt':
|
||||||
|
lpt = borg.forward.model_lib.M_2LPT_CIC(
|
||||||
|
box,
|
||||||
|
opts=dict(a_initial=af,
|
||||||
|
a_final=af,
|
||||||
|
do_rsd=False,
|
||||||
|
supersampling=1,
|
||||||
|
lightcone=False,
|
||||||
|
part_factor=1.01,))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(config['model']['gravity'])
|
||||||
|
|
||||||
|
lpt.accumulateAdjoint(True)
|
||||||
|
chain @= lpt
|
||||||
|
|
||||||
|
if config['emulator']['use_emulator'].lower().strip() == 'true':
|
||||||
|
myprint('Adding emulator to the chain')
|
||||||
|
|
||||||
|
# Check device:
|
||||||
|
cuda_avail = torch.cuda.is_available()
|
||||||
|
device = torch.device('cuda' if cuda_avail else 'cpu')
|
||||||
|
myprint(f'device = {device}')
|
||||||
|
if cuda_avail:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
model_weights_path = config['emulator']['model_weights_path']
|
||||||
|
myprint(f'Use model params from: {model_weights_path}')
|
||||||
|
emu_weights = torch.load(model_weights_path, map_location=device)
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
model = getattr(networks, config['emulator']['architecture'])
|
||||||
|
# if use_distilled:
|
||||||
|
# model = networks.StyledVNet_distill(1,3,3,num_filt=num_filt)
|
||||||
|
# else:
|
||||||
|
# model = networks.StyledVNet(1,3,3)
|
||||||
|
model.load_state_dict(emu_weights['model'])
|
||||||
|
|
||||||
|
use_float64 = config['emulator']['use_float64'].lower().strip() == 'true'
|
||||||
|
if use_float64:
|
||||||
|
model.double()
|
||||||
|
dtype = torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
# Extract omega as style param
|
||||||
|
Om = cpar.omega_m
|
||||||
|
Om = torch.tensor([Om],dtype=dtype) # style parameter
|
||||||
|
|
||||||
|
# from emulator hacking:
|
||||||
|
Om -= torch.tensor([0.3])
|
||||||
|
Om *= torch.tensor([5.0])
|
||||||
|
if cuda_avail:
|
||||||
|
Om = Om.cuda()
|
||||||
|
|
||||||
|
# Initial positions
|
||||||
|
q = initial_pos(box.L[0],box.N[0])
|
||||||
|
|
||||||
|
# Create module in BORG chain
|
||||||
|
emu = forwards.Emulator(
|
||||||
|
box,
|
||||||
|
lpt,
|
||||||
|
model,
|
||||||
|
Om,
|
||||||
|
config['emulator']['use_pad_and_NN'].strip().lower() == 'true',
|
||||||
|
use_float64,
|
||||||
|
q,
|
||||||
|
config['emulator']['requires_grad'].strip().lower() == 'true',
|
||||||
|
cuda_avail,
|
||||||
|
debug=False
|
||||||
|
)
|
||||||
|
chain.addModel(emu)
|
||||||
|
|
||||||
|
# DIFFERENT TO LUDVIG - CHECK THIS
|
||||||
|
chain @= emu
|
||||||
|
|
||||||
|
# This is the forward model for the model parameters
|
||||||
|
fwd_param = borg.forward.ChainForwardModel(box)
|
||||||
|
mod_null = forwards.NullForward(box)
|
||||||
|
fwd_param.addModel(mod_null)
|
||||||
|
fwd_param.setCosmoParams(cpar)
|
||||||
|
|
||||||
|
return chain
|
||||||
|
|
||||||
|
_glob_model = None
|
||||||
|
_glob_cosmo = None
|
||||||
|
begin_model = None
|
||||||
|
begin_cosmo = None
|
||||||
|
|
||||||
|
def check_model_sampling(loop):
|
||||||
|
return loop.getStepID() > begin_model
|
||||||
|
|
||||||
|
def check_cosmo_sampling(loop):
|
||||||
|
return loop.getStepID() > begin_cosmo
|
||||||
|
|
||||||
|
|
||||||
|
@borg.registerSamplerBuilder
|
||||||
|
def build_sampler(
|
||||||
|
state: borg.likelihood.MarkovState,
|
||||||
|
info: borg.likelihood.LikelihoodInfo,
|
||||||
|
loop: borg.samplers.MainLoop
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Builds the sampler and returns it.
|
||||||
|
Which parameters to sample are given in the ini file.
|
||||||
|
We assume all parameters are NOT meant to be sampled, unless we find "XX_sampler_blocked = false" in the ini file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
||||||
|
- info (borg.likelihood.LikelihoodInfo): The likelihood information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of samplers to use.
|
||||||
|
|
||||||
|
"""
|
||||||
|
global _glob_model, _glob_cosmo, begin_model, begin_cosmo
|
||||||
|
borg.print_msg(borg.Level.std, "Hello sampler, loop is {l}, step_id={s}", l=loop, s=loop.getStepID())
|
||||||
|
|
||||||
|
myprint("Building sampler")
|
||||||
|
|
||||||
|
myprint("Reading from configuration file: " + borg.getIniConfigurationFilename())
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(borg.getIniConfigurationFilename())
|
||||||
|
end = '_sampler_blocked'
|
||||||
|
to_sample = [k[:-len(end)] for (k, v) in config['block_loop'].items() if k[-len(end):] == end and v.lower() == 'false']
|
||||||
|
myprint(f'Parameters to sample: {to_sample}')
|
||||||
|
nsamp = int(config['run']['nsamp'])
|
||||||
|
|
||||||
|
all_sampler = []
|
||||||
|
|
||||||
|
# Cosmology sampler arguments
|
||||||
|
prefix = ""
|
||||||
|
params = []
|
||||||
|
initial_values = {}
|
||||||
|
prior = {}
|
||||||
|
for p in ["omega_m", "sigma8"]:
|
||||||
|
if p not in to_sample:
|
||||||
|
continue
|
||||||
|
if p in config['prior'].keys() and p in config['cosmology'].keys():
|
||||||
|
myprint(f'Adding {p} sampler')
|
||||||
|
params.append(f"cosmology.{p}")
|
||||||
|
initial_values[f"cosmology.{p}"] = float(config['cosmology'][p])
|
||||||
|
prior[f"cosmology.{p}"] = np.array(ast.literal_eval(config['prior'][p]))
|
||||||
|
else:
|
||||||
|
s = f'Could not find {p} prior and/or default, so will not sample'
|
||||||
|
warnings.warn(s, stacklevel=2)
|
||||||
|
# Remove for later to prevent duplication
|
||||||
|
to_sample.remove(p)
|
||||||
|
|
||||||
|
begin_cosmo = int(config['mcmc']['warmup_cosmo'])
|
||||||
|
|
||||||
|
if len(params) > 0:
|
||||||
|
myprint('Adding cosmological parameter sampler')
|
||||||
|
cosmo_sampler = borg.samplers.ModelParamsSampler(prefix, params, likelihood, chain, initial_values, prior)
|
||||||
|
cosmo_sampler.setName("cosmo_sampler")
|
||||||
|
_glob_cosmo = cosmo_sampler
|
||||||
|
all_sampler.append(cosmo_sampler)
|
||||||
|
loop.push(cosmo_sampler)
|
||||||
|
loop.addToConditionGroup("warmup_cosmo", "cosmo_sampler")
|
||||||
|
loop.addConditionToConditionGroup("warmup_cosmo", partial(check_cosmo_sampling, loop))
|
||||||
|
|
||||||
|
# Model parameter sampler
|
||||||
|
prefix = ""
|
||||||
|
params = []
|
||||||
|
initial_values = {}
|
||||||
|
prior = {}
|
||||||
|
for p in to_sample:
|
||||||
|
if p in config['prior'].keys():
|
||||||
|
myprint(f'Adding {p} sampler')
|
||||||
|
params.append(p)
|
||||||
|
initial_values[f'{p}'] = float(config[f'model'][p])
|
||||||
|
if 'inf' in config['prior'][p]:
|
||||||
|
x = ast.literal_eval(config['prior'][p].replace('inf', '"inf"'))
|
||||||
|
prior[p] = np.array([xx if xx != 'inf' else np.inf for xx in x])
|
||||||
|
else:
|
||||||
|
prior[p] = np.array(ast.literal_eval(config['prior'][p]))
|
||||||
|
else:
|
||||||
|
s = f'Could not find {p} prior, so will not sample'
|
||||||
|
warnings.warn(s, stacklevel=2)
|
||||||
|
|
||||||
|
begin_model = int(config['mcmc']['warmup_model'])
|
||||||
|
|
||||||
|
if len(params) > 0:
|
||||||
|
myprint('Adding model parameter sampler')
|
||||||
|
model_sampler = borg.samplers.ModelParamsSampler(prefix, params, likelihood, fwd_param, initial_values, prior)
|
||||||
|
model_sampler.setName("model_sampler")
|
||||||
|
_glob_model = model_sampler
|
||||||
|
loop.push(model_sampler)
|
||||||
|
all_sampler.append(model_sampler)
|
||||||
|
loop.addToConditionGroup("warmup_model", "model_sampler")
|
||||||
|
loop.addConditionToConditionGroup("warmup_model", partial(check_model_sampling, loop))
|
||||||
|
|
||||||
|
print('Warmups:', begin_cosmo, begin_model)
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@borg.registerLikelihoodBuilder
|
||||||
|
def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.LikelihoodInfo) -> borg.likelihood.BaseLikelihood:
|
||||||
|
"""
|
||||||
|
Builds the likelihood object and returns it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
|
||||||
|
- info (borg.likelihood.LikelihoodInfo): The likelihood information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
borg.likelihood.BaseLikelihood: The likelihood object.
|
||||||
|
|
||||||
|
"""
|
||||||
|
global likelihood, fwd_param
|
||||||
|
myprint("Building likelihood")
|
||||||
|
myprint(chain.getCosmoParams())
|
||||||
|
boxm = chain.getBoxModel()
|
||||||
|
likelihood = VelocityBORGLikelihood(chain, fwd_param, fwd_vel, borg.getIniConfigurationFilename())
|
||||||
|
return likelihood
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
NOTES:
|
||||||
|
- Currently if we update Om, then we don't seem to update the emulator
|
||||||
|
- Few arguments missing from forwards.Emulator
|
||||||
|
"""
|
42
mgborg_emulator/map2map_tools/narrow.py
Normal file
42
mgborg_emulator/map2map_tools/narrow.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def narrow_by(a, c):
|
||||||
|
"""Narrow a by size c symmetrically on all edges.
|
||||||
|
"""
|
||||||
|
ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
|
||||||
|
return a[ind]
|
||||||
|
|
||||||
|
|
||||||
|
def narrow_cast(*tensors):
|
||||||
|
"""Narrow each tensor to the minimum length in each dimension.
|
||||||
|
|
||||||
|
Try to be symmetric but cut more on the right for odd difference
|
||||||
|
"""
|
||||||
|
dim_max = max(a.dim() for a in tensors)
|
||||||
|
|
||||||
|
len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}
|
||||||
|
|
||||||
|
casted_tensors = []
|
||||||
|
for a in tensors:
|
||||||
|
for d in range(2, dim_max):
|
||||||
|
width = a.shape[d] - len_min[d]
|
||||||
|
half_width = width // 2
|
||||||
|
a = a.narrow(d, half_width, a.shape[d] - width)
|
||||||
|
|
||||||
|
casted_tensors.append(a)
|
||||||
|
|
||||||
|
return casted_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def narrow_like(a, b):
|
||||||
|
"""Narrow a to be like b.
|
||||||
|
|
||||||
|
Try to be symmetric but cut more on the right for odd difference
|
||||||
|
"""
|
||||||
|
for d in range(2, a.dim()):
|
||||||
|
width = a.shape[d] - b.shape[d]
|
||||||
|
half_width = width // 2
|
||||||
|
a = a.narrow(d, half_width, a.shape[d] - width)
|
||||||
|
return a
|
179
mgborg_emulator/map2map_tools/style.py
Normal file
179
mgborg_emulator/map2map_tools/style.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class PixelNorm(nn.Module):
|
||||||
|
"""Pixelwise normalization after conv layers.
|
||||||
|
|
||||||
|
See ProGAN, StyleGAN.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, eps=1e-8):
|
||||||
|
return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearElr(nn.Module):
|
||||||
|
"""Linear layer with equalized learning rate.
|
||||||
|
|
||||||
|
See ProGAN, StyleGAN, and 1706.05350
|
||||||
|
|
||||||
|
Useful at all if not for regularization(1706.05350)?
|
||||||
|
"""
|
||||||
|
def __init__(self, in_size, out_size, bias=True, act=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.randn(out_size, in_size))
|
||||||
|
self.wnorm = 1 / math.sqrt(in_size)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_size))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
self.act = act
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.linear(x, self.weight * self.wnorm, bias=self.bias)
|
||||||
|
|
||||||
|
if self.act:
|
||||||
|
x = F.leaky_relu(x, negative_slope=0.2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvElr3d(nn.Module):
|
||||||
|
"""Conv3d layer with equalized learning rate.
|
||||||
|
|
||||||
|
See ProGAN, StyleGAN, and 1706.05350
|
||||||
|
|
||||||
|
Useful at all if not for regularization(1706.05350)?
|
||||||
|
"""
|
||||||
|
def __init__(self, in_chan, out_chan, kernel_size,
|
||||||
|
stride=1, padding=0, bias=True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(
|
||||||
|
torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
|
||||||
|
)
|
||||||
|
fan_in = in_chan * kernel_size ** 3
|
||||||
|
self.wnorm = 1 / math.sqrt(fan_in)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_chan))
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.conv2d(
|
||||||
|
x,
|
||||||
|
self.weight * self.wnorm,
|
||||||
|
bias=self.bias,
|
||||||
|
stride=self.stride,
|
||||||
|
padding=self.padding,
|
||||||
|
)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConvStyled3d(nn.Module):
|
||||||
|
"""Convolution layer with modulation and demodulation, from StyleGAN2.
|
||||||
|
|
||||||
|
Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
|
||||||
|
"""
|
||||||
|
def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
|
||||||
|
bias=True, resample=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
|
||||||
|
nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
|
||||||
|
mode='fan_in', nonlinearity='leaky_relu')
|
||||||
|
self.style_bias = nn.Parameter(torch.ones(in_chan)) # NOTE: init to 1
|
||||||
|
|
||||||
|
if resample is None:
|
||||||
|
K3 = (kernel_size,) * 3
|
||||||
|
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
||||||
|
self.stride = stride
|
||||||
|
self.conv = F.conv3d
|
||||||
|
elif resample == 'U':
|
||||||
|
K3 = (2,) * 3
|
||||||
|
# NOTE not clear to me why convtranspose have channels swapped
|
||||||
|
self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
|
||||||
|
self.stride = 2
|
||||||
|
self.conv = F.conv_transpose3d
|
||||||
|
elif resample == 'D':
|
||||||
|
K3 = (2,) * 3
|
||||||
|
self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
|
||||||
|
self.stride = 2
|
||||||
|
self.conv = F.conv3d
|
||||||
|
else:
|
||||||
|
raise ValueError('resample type {} not supported'.format(resample))
|
||||||
|
self.resample = resample
|
||||||
|
|
||||||
|
nn.init.kaiming_uniform_(
|
||||||
|
self.weight, a=math.sqrt(5),
|
||||||
|
mode='fan_in', # effectively 'fan_out' for 'D'
|
||||||
|
nonlinearity='leaky_relu',
|
||||||
|
)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_chan))
|
||||||
|
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||||
|
bound = 1 / math.sqrt(fan_in)
|
||||||
|
nn.init.uniform_(self.bias, -bound, bound)
|
||||||
|
else:
|
||||||
|
self.register_parameter('bias', None)
|
||||||
|
|
||||||
|
def forward(self, x, s, eps=1e-8):
|
||||||
|
N, Cin, *DHWin = x.shape
|
||||||
|
C0, C1, *K3 = self.weight.shape
|
||||||
|
if self.resample == 'U':
|
||||||
|
Cin, Cout = C0, C1
|
||||||
|
else:
|
||||||
|
Cout, Cin = C0, C1
|
||||||
|
|
||||||
|
s = F.linear(s, self.style_weight, bias=self.style_bias)
|
||||||
|
|
||||||
|
# modulation
|
||||||
|
if self.resample == 'U':
|
||||||
|
s = s.reshape(N, Cin, 1, 1, 1, 1)
|
||||||
|
else:
|
||||||
|
s = s.reshape(N, 1, Cin, 1, 1, 1)
|
||||||
|
w = self.weight * s
|
||||||
|
|
||||||
|
# demodulation
|
||||||
|
if self.resample == 'U':
|
||||||
|
fan_in_dim = (1, 3, 4, 5)
|
||||||
|
else:
|
||||||
|
fan_in_dim = (2, 3, 4, 5)
|
||||||
|
w = w * torch.rsqrt(w.pow(2).sum(dim=fan_in_dim, keepdim=True) + eps)
|
||||||
|
|
||||||
|
w = w.reshape(N * C0, C1, *K3)
|
||||||
|
x = x.reshape(1, N * Cin, *DHWin)
|
||||||
|
x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
|
||||||
|
_, _, *DHWout = x.shape
|
||||||
|
x = x.reshape(N, Cout, *DHWout)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class BatchNormStyled3d(nn.BatchNorm3d) :
|
||||||
|
""" Trivially does standard batch normalization, but accepts second argument
|
||||||
|
|
||||||
|
for style array that is not used
|
||||||
|
"""
|
||||||
|
def forward(self, x, s):
|
||||||
|
return super().forward(x)
|
||||||
|
|
||||||
|
class LeakyReLUStyled(nn.LeakyReLU):
|
||||||
|
""" Trivially evaluates standard leaky ReLU, but accepts second argument
|
||||||
|
|
||||||
|
for sytle array that is not used
|
||||||
|
"""
|
||||||
|
def forward(self, x, s):
|
||||||
|
return super().forward(x)
|
141
mgborg_emulator/map2map_tools/styled_conv.py
Normal file
141
mgborg_emulator/map2map_tools/styled_conv.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
import warnings
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from map2map_tools.narrow import narrow_like
|
||||||
|
from map2map_tools.style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled
|
||||||
|
|
||||||
|
|
||||||
|
class ConvStyledBlock(nn.Module):
|
||||||
|
"""Convolution blocks of the form specified by `seq`.
|
||||||
|
|
||||||
|
`seq` types:
|
||||||
|
'C': convolution specified by `kernel_size` and `stride`
|
||||||
|
'B': normalization (to be renamed to 'N')
|
||||||
|
'A': activation
|
||||||
|
'U': upsampling transposed convolution of kernel size 2 and stride 2
|
||||||
|
'D': downsampling convolution of kernel size 2 and stride 2
|
||||||
|
"""
|
||||||
|
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
||||||
|
kernel_size=3, stride=1, seq='CBA'):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if out_chan is None:
|
||||||
|
out_chan = in_chan
|
||||||
|
|
||||||
|
self.style_size = style_size
|
||||||
|
self.in_chan = in_chan
|
||||||
|
self.out_chan = out_chan
|
||||||
|
if mid_chan is None:
|
||||||
|
self.mid_chan = max(in_chan, out_chan)
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
self.norm_chan = in_chan
|
||||||
|
self.idx_conv = 0
|
||||||
|
self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])
|
||||||
|
|
||||||
|
layers = [self._get_layer(l) for l in seq]
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList(layers)
|
||||||
|
|
||||||
|
def _get_layer(self, l):
|
||||||
|
if l == 'U':
|
||||||
|
in_chan, out_chan = self._setup_conv()
|
||||||
|
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
|
||||||
|
resample = 'U')
|
||||||
|
elif l == 'D':
|
||||||
|
in_chan, out_chan = self._setup_conv()
|
||||||
|
return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
|
||||||
|
resample = 'D')
|
||||||
|
elif l == 'C':
|
||||||
|
in_chan, out_chan = self._setup_conv()
|
||||||
|
return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size,
|
||||||
|
stride=self.stride)
|
||||||
|
elif l == 'B':
|
||||||
|
return BatchNormStyled3d(self.norm_chan)
|
||||||
|
elif l == 'A':
|
||||||
|
return LeakyReLUStyled()
|
||||||
|
else:
|
||||||
|
raise ValueError('layer type {} not supported'.format(l))
|
||||||
|
|
||||||
|
def _setup_conv(self):
|
||||||
|
self.idx_conv += 1
|
||||||
|
|
||||||
|
in_chan = out_chan = self.mid_chan
|
||||||
|
if self.idx_conv == 1:
|
||||||
|
in_chan = self.in_chan
|
||||||
|
if self.idx_conv == self.num_conv:
|
||||||
|
out_chan = self.out_chan
|
||||||
|
|
||||||
|
self.norm_chan = out_chan
|
||||||
|
|
||||||
|
return in_chan, out_chan
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x, s)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResStyledBlock(ConvStyledBlock):
|
||||||
|
"""Residual convolution blocks of the form specified by `seq`.
|
||||||
|
Input, via a skip connection, is added to the residual followed by an
|
||||||
|
optional activation.
|
||||||
|
|
||||||
|
The skip connection is identity if `out_chan` is omitted, otherwise it uses
|
||||||
|
a size 1 "convolution", i.e. one can trigger the latter by setting
|
||||||
|
`out_chan` even if it equals `in_chan`.
|
||||||
|
|
||||||
|
A trailing `'A'` in seq can either operate before or after the addition,
|
||||||
|
depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`
|
||||||
|
|
||||||
|
See `ConvStyledBlock` for `seq` types.
|
||||||
|
"""
|
||||||
|
def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
|
||||||
|
kernel_size=3, stride=1, seq='CBACBA', last_act=None):
|
||||||
|
if last_act is None:
|
||||||
|
last_act = seq[-1] == 'A'
|
||||||
|
elif last_act and seq[-1] != 'A':
|
||||||
|
warnings.warn(
|
||||||
|
'Disabling last_act without trailing activation in seq',
|
||||||
|
RuntimeWarning,
|
||||||
|
)
|
||||||
|
last_act = False
|
||||||
|
|
||||||
|
if last_act:
|
||||||
|
seq = seq[:-1]
|
||||||
|
|
||||||
|
super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
|
||||||
|
kernel_size=kernel_size, stride=stride, seq=seq)
|
||||||
|
|
||||||
|
if last_act:
|
||||||
|
self.act = LeakyReLUStyled()
|
||||||
|
else:
|
||||||
|
self.act = None
|
||||||
|
|
||||||
|
if out_chan is None:
|
||||||
|
self.skip = None
|
||||||
|
else:
|
||||||
|
self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)
|
||||||
|
|
||||||
|
if 'U' in seq or 'D' in seq:
|
||||||
|
raise NotImplementedError('upsample and downsample layers '
|
||||||
|
'not supported yet')
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
y = x
|
||||||
|
|
||||||
|
if self.skip is not None:
|
||||||
|
y = self.skip(y, s)
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x, s)
|
||||||
|
|
||||||
|
y = narrow_like(y, x)
|
||||||
|
x += y
|
||||||
|
|
||||||
|
if self.act is not None:
|
||||||
|
x = self.act(x, s)
|
||||||
|
|
||||||
|
return x
|
159
mgborg_emulator/networks.py
Normal file
159
mgborg_emulator/networks.py
Normal file
|
@ -0,0 +1,159 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .styled_conv import ConvStyledBlock, ResStyledBlock
|
||||||
|
from .narrow import narrow_by
|
||||||
|
|
||||||
|
class StyledVNet(nn.Module):
|
||||||
|
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
|
||||||
|
"""V-Net like network with styles
|
||||||
|
|
||||||
|
See `vnet.VNet`.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# activate non-identity skip connection in residual block
|
||||||
|
# by explicitly setting out_chan
|
||||||
|
self.conv_l00 = ResStyledBlock(style_size, in_chan, 64, seq='CACA')
|
||||||
|
self.conv_l01 = ResStyledBlock(style_size, 64, 64, seq='CACA')
|
||||||
|
self.down_l0 = ConvStyledBlock(style_size, 64, seq='DA')
|
||||||
|
self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CACA')
|
||||||
|
self.down_l1 = ConvStyledBlock(style_size, 64, seq='DA')
|
||||||
|
self.conv_l2 = ResStyledBlock(style_size, 64, 64, seq='CACA')
|
||||||
|
self.down_l2 = ConvStyledBlock(style_size, 64, seq='DA')
|
||||||
|
|
||||||
|
self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CACA')
|
||||||
|
|
||||||
|
self.up_r2 = ConvStyledBlock(style_size, 64, seq='UA')
|
||||||
|
self.conv_r2 = ResStyledBlock(style_size, 128, 64, seq='CACA')
|
||||||
|
self.up_r1 = ConvStyledBlock(style_size, 64, seq='UA')
|
||||||
|
self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CACA')
|
||||||
|
self.up_r0 = ConvStyledBlock(style_size, 64, seq='UA')
|
||||||
|
self.conv_r00 = ResStyledBlock(style_size, 128, 64, seq='CACA')
|
||||||
|
self.conv_r01 = ResStyledBlock(style_size, 64, out_chan, seq='CAC')
|
||||||
|
|
||||||
|
if bypass is None:
|
||||||
|
self.bypass = in_chan == out_chan
|
||||||
|
else:
|
||||||
|
self.bypass = bypass
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
if self.bypass:
|
||||||
|
x0 = x
|
||||||
|
|
||||||
|
x = self.conv_l00(x, s)
|
||||||
|
y0 = self.conv_l01(x, s)
|
||||||
|
x = self.down_l0(y0, s)
|
||||||
|
|
||||||
|
y1 = self.conv_l1(x, s)
|
||||||
|
x = self.down_l1(y1, s)
|
||||||
|
|
||||||
|
y2 = self.conv_l2(x, s)
|
||||||
|
x = self.down_l2(y2, s)
|
||||||
|
|
||||||
|
x = self.conv_c(x, s)
|
||||||
|
|
||||||
|
x = self.up_r2(x, s)
|
||||||
|
y2 = narrow_by(y2, 4)
|
||||||
|
x = torch.cat([y2, x], dim=1)
|
||||||
|
del y2
|
||||||
|
x = self.conv_r2(x, s)
|
||||||
|
|
||||||
|
x = self.up_r1(x, s)
|
||||||
|
y1 = narrow_by(y1, 16)
|
||||||
|
x = torch.cat([y1, x], dim=1)
|
||||||
|
del y1
|
||||||
|
x = self.conv_r1(x, s)
|
||||||
|
|
||||||
|
x = self.up_r0(x, s)
|
||||||
|
y0 = narrow_by(y0, 40)
|
||||||
|
x = torch.cat([y0, x], dim=1)
|
||||||
|
del y0
|
||||||
|
x = self.conv_r00(x, s)
|
||||||
|
x = self.conv_r01(x, s)
|
||||||
|
|
||||||
|
if self.bypass:
|
||||||
|
x0 = narrow_by(x0, 48)
|
||||||
|
x += x0
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class StyledVNet_distill(nn.Module):
|
||||||
|
def __init__(self, style_size, in_chan, out_chan, num_filt=32, alpha=[1,1,1,1], bypass=None, **kwargs):
|
||||||
|
"""V-Net like network with styles
|
||||||
|
|
||||||
|
See `vnet.VNet`.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
print(f'Number of filters = {num_filt}')
|
||||||
|
if alpha is None:
|
||||||
|
alpha = [1,1,1,1]
|
||||||
|
print(f'alpha = {alpha}')
|
||||||
|
|
||||||
|
# activate non-identity skip connection in residual block
|
||||||
|
# by explicitly setting out_chan
|
||||||
|
self.conv_l00 = ResStyledBlock(style_size, in_chan, num_filt*alpha[0], seq='CACA')
|
||||||
|
self.conv_l01 = ResStyledBlock(style_size, num_filt*alpha[0], num_filt*alpha[1], seq='CACA')
|
||||||
|
self.down_l0 = ConvStyledBlock(style_size, num_filt*alpha[1], seq='DA')
|
||||||
|
self.conv_l1 = ResStyledBlock(style_size, num_filt*alpha[1], num_filt*alpha[2], seq='CACA')
|
||||||
|
self.down_l1 = ConvStyledBlock(style_size, num_filt*alpha[2], seq='DA')
|
||||||
|
self.conv_l2 = ResStyledBlock(style_size, num_filt*alpha[2], num_filt*alpha[3], seq='CACA')
|
||||||
|
self.down_l2 = ConvStyledBlock(style_size, num_filt*alpha[3], seq='DA')
|
||||||
|
|
||||||
|
self.conv_c = ResStyledBlock(style_size, num_filt*alpha[3], num_filt*alpha[3], seq='CACA')
|
||||||
|
|
||||||
|
self.up_r2 = ConvStyledBlock(style_size, num_filt*alpha[3], seq='UA')
|
||||||
|
self.conv_r2 = ResStyledBlock(style_size, num_filt*2*alpha[3], num_filt*alpha[3], seq='CACA')
|
||||||
|
self.up_r1 = ConvStyledBlock(style_size, num_filt*alpha[2], seq='UA')
|
||||||
|
self.conv_r1 = ResStyledBlock(style_size, num_filt*2*alpha[2], num_filt*alpha[2], seq='CACA')
|
||||||
|
self.up_r0 = ConvStyledBlock(style_size, num_filt*alpha[1], seq='UA')
|
||||||
|
self.conv_r00 = ResStyledBlock(style_size, num_filt*2*alpha[1], num_filt*alpha[1], seq='CACA')
|
||||||
|
self.conv_r01 = ResStyledBlock(style_size, num_filt*alpha[0], out_chan, seq='CAC')
|
||||||
|
|
||||||
|
if bypass is None:
|
||||||
|
self.bypass = in_chan == out_chan
|
||||||
|
else:
|
||||||
|
self.bypass = bypass
|
||||||
|
|
||||||
|
def forward(self, x, s):
|
||||||
|
if self.bypass:
|
||||||
|
x0 = x
|
||||||
|
|
||||||
|
x = self.conv_l00(x, s)
|
||||||
|
y0 = self.conv_l01(x, s)
|
||||||
|
x = self.down_l0(y0, s)
|
||||||
|
|
||||||
|
y1 = self.conv_l1(x, s)
|
||||||
|
x = self.down_l1(y1, s)
|
||||||
|
|
||||||
|
y2 = self.conv_l2(x, s)
|
||||||
|
x = self.down_l2(y2, s)
|
||||||
|
|
||||||
|
x = self.conv_c(x, s)
|
||||||
|
|
||||||
|
x = self.up_r2(x, s)
|
||||||
|
y2 = narrow_by(y2, 4)
|
||||||
|
x = torch.cat([y2, x], dim=1)
|
||||||
|
del y2
|
||||||
|
x = self.conv_r2(x, s)
|
||||||
|
|
||||||
|
x = self.up_r1(x, s)
|
||||||
|
y1 = narrow_by(y1, 16)
|
||||||
|
x = torch.cat([y1, x], dim=1)
|
||||||
|
del y1
|
||||||
|
x = self.conv_r1(x, s)
|
||||||
|
|
||||||
|
x = self.up_r0(x, s)
|
||||||
|
y0 = narrow_by(y0, 40)
|
||||||
|
x = torch.cat([y0, x], dim=1)
|
||||||
|
del y0
|
||||||
|
x = self.conv_r00(x, s)
|
||||||
|
x = self.conv_r01(x, s)
|
||||||
|
|
||||||
|
if self.bypass:
|
||||||
|
x0 = narrow_by(x0, 48)
|
||||||
|
x += x0
|
||||||
|
|
||||||
|
return x
|
36
mgborg_emulator/utils.py
Normal file
36
mgborg_emulator/utils.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
import aquila_borg as borg
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
cons = borg.console()
|
||||||
|
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))
|
||||||
|
|
||||||
|
def get_cosmopar(ini_file):
|
||||||
|
"""
|
||||||
|
Extract cosmological parameters from an ini file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
:ini_file (str): Path to the ini file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read(ini_file)
|
||||||
|
|
||||||
|
cpar = borg.cosmo.CosmologicalParameters()
|
||||||
|
cpar.default()
|
||||||
|
cpar.fnl = float(config['cosmology']['fnl'])
|
||||||
|
cpar.omega_k = float(config['cosmology']['omega_k'])
|
||||||
|
cpar.omega_m = float(config['cosmology']['omega_m'])
|
||||||
|
cpar.omega_b = float(config['cosmology']['omega_b'])
|
||||||
|
cpar.omega_q = float(config['cosmology']['omega_q'])
|
||||||
|
cpar.h = float(config['cosmology']['h100'])
|
||||||
|
cpar.sigma8 = float(config['cosmology']['sigma8'])
|
||||||
|
cpar.n_s = float(config['cosmology']['n_s'])
|
||||||
|
cpar.w = float(config['cosmology']['w'])
|
||||||
|
cpar.wprime = float(config['cosmology']['wprime'])
|
||||||
|
cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
|
||||||
|
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
|
||||||
|
|
||||||
|
return cpar
|
Loading…
Add table
Reference in a new issue