Basic structure for forwards and likelihood

parent 80c0224e71
commit 21e8fd82ec
8 changed files with 1625 additions and 0 deletions
82 conf/basic_ini.ini Normal file
@@ -0,0 +1,82 @@
[system]
console_output = borg_log
VERBOSE_LEVEL = 2
N0 = 128
N1 = 128
N2 = 128
L0 = 500.0
L1 = 500.0
L2 = 500.0
corner0 = -250.0
corner1 = -250.0
corner2 = -250.0
NUM_MODES = 100
test_mode = true
seed_cpower = true

[block_loop]
hades_sampler_blocked = false
bias_sampler_blocked = true
gauss_sigma_sampler_blocked = false
ares_heat = 1.0

[mcmc]
number_to_generate = 10
warmup_model = 3
warmup_cosmo = 7
random_ic = false
init_random_scaling = 0.1
bignum = 1e300

[hades]
algorithm = HMC
max_epsilon = 0.01
max_timesteps = 50
mixing = 1

[model]
gravity = lpt
logfR0 = -5.0
af = 1.0
ai = 0.05

[prior]
omega_m = [0.1, 0.8]
sigma8 = [0.1, 1.5]
muA = [0.5, 1.5]
alpha = [0.0, 10.0]
sig_v = [50.0, 200.0]
bulk_flow = [-200.0, 200.0]

[cosmology]
omega_r = 0
fnl = 0
omega_k = 0
omega_m = 0.315
omega_b = 0.049
omega_q = 0.685
h100 = 0.68
sigma8 = 0.81
n_s = 0.97
w = -1
wprime = 0
beta = 1.5
z0 = 0

[emulator]
use_emulator = True
model_weights_path = /home/bartlett/mgborg_emulator/weights/emulator_weights.file_type
architecture = StyledVNet
use_float64 = False
use_pad_and_NN = True
requires_grad = False

[run]
run_type = mock
NCAT = 0

[mock]
seed = 123

[python]
likelihood_path = /home/bartlett/mgborg_emulator/mgborg_emulator/likelihood.py
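A quick way to sanity-check this configuration is to parse it exactly the way the likelihood below does. A minimal sketch (configparser is standard library; the only assumption is that the file is saved as conf/basic_ini.ini):

# Minimal sketch: read the grid set-up the same way GaussianLikelihood does.
import configparser

config = configparser.ConfigParser()
config.read('conf/basic_ini.ini')

N = [int(config['system'][f'N{i}']) for i in range(3)]    # [128, 128, 128]
L = [float(config['system'][f'L{i}']) for i in range(3)]  # [500.0, 500.0, 500.0]
use_emu = config['emulator']['use_emulator'].lower().strip() == 'true'
print(N, L, use_emu)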
470 mgborg_emulator/forwards.py Normal file
@@ -0,0 +1,470 @@
import aquila_borg as borg
import time
import numpy as np

# Emulator imports
import torch
# NOTE: dis() and correct_displacement_over_periodic_boundaries() used below are
# expected to come from this (still commented-out) import:
# from .model_architecture.cosmology import *
from networks import StyledVNet, StyledVNet_distill

# JAX set-up
import jax
from functools import partial
import jax.numpy as jnp
from jax import vjp
from jax.config import config as jax_config
jax_config.update("jax_enable_x64", True)

from utils import myprint

class Emulator(borg.forward.BaseForwardModel):

    def __init__(self, box, prev_module, NN, Om, upan, use_float64, q, requires_grad, cuda_avail, debug=False):
        super().__init__(box, box)

        # Set-up and cosmo
        self.box = box
        self.q = q
        self.Om = Om
        self.debug = debug

        # LPT and EMU
        self.prev_module = prev_module
        self.NN = NN

        # Since we won't return an adjoint for the density, but for positions, we need to do:
        self.prev_module.accumulateAdjoint(True)

        # Settings
        self.requires_grad = requires_grad
        self.use_pad_and_NN = upan
        self.use_float64 = use_float64
        self.cuda_avail = cuda_avail
        if self.use_float64:
            self.dtype = np.float64
        else:
            self.dtype = np.float32

    def requires_grad_true(self):
        self.requires_grad = True

    def requires_grad_false(self):
        self.requires_grad = False

    # IO "preferences"
    def getPreferredInput(self):
        return borg.forward.PREFERRED_REAL

    def getPreferredOutput(self):
        return borg.forward.PREFERRED_REAL

    # Forward part
    def forwardModel_v2_impl(self, input_array):
        if self.debug:
            myprint(f'Requires grad = {self.requires_grad}')

        init_time = time.time()

        # Step 0 - Extract particle positions
        if self.debug:
            start_time = time.time()
        pos = np.zeros((self.prev_module.getNumberOfParticles(), 3))  # output shape: (N^3, 3)
        self.prev_module.getParticlePositions(pos)
        if self.debug:
            myprint("Step 0 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 1 - find displacements
        if self.debug:
            start_time = time.time()
        disp = pos - self.q
        if self.debug:
            myprint("Step 1 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 2 - correct for particles that moved over the periodic boundary
        if self.debug:
            start_time = time.time()
        disp_temp = correct_displacement_over_periodic_boundaries(disp, L=self.box.L[0], max_disp_1d=self.box.L[0]//2)
        if self.debug:
            myprint("Step 2 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 3 - reshaping initial pos and displacement
        if self.debug:
            start_time = time.time()
        # not sure why order='C' is working here... not sure if it matters... could change it below
        q_reshaped = np.reshape(self.q.T, (3, self.box.N[0], self.box.N[0], self.box.N[0]), order='C')  # output shape: (3, N, N, N)
        dis_in = np.reshape(disp_temp.T, (3, self.box.N[0], self.box.N[0], self.box.N[0]), order='C')  # output shape: (3, N, N, N)
        if self.debug:
            self.dis_in = dis_in
            myprint("Step 3 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 4 - normalize
        if self.debug:
            start_time = time.time()
        dis(dis_in)
        if self.debug:
            myprint("Step 4 of forward pass took %s seconds" % (time.time() - start_time))

        if self.use_pad_and_NN:
            if self.debug:
                myprint('Using padding and NN.')

            if not self.use_float64:
                dis_in = dis_in.astype(np.float32)

            # Step 5 - padding to (3,N+48*2,N+48*2,N+48*2)
            if self.debug:
                start_time = time.time()

            # for jax adjoint
            if self.debug:
                myprint(f'in_pad shape = {dis_in.shape}')
            dis_in_padded, self.ag_pad = vjp(self.padding, dis_in)  # output shape: (3, N+96, N+96, N+96)
            if self.debug:
                myprint(f'out_pad shape = {np.shape(np.asarray(dis_in_padded))}')
                myprint("Step 5 of forward pass took %s seconds" % (time.time() - start_time))

            # Step 6 - turn into a pytorch tensor (unsqueeze because batch = 1)
            if self.debug:
                start_time = time.time()
            if self.use_float64:
                self.x = torch.unsqueeze(torch.tensor(np.asarray(dis_in_padded), dtype=torch.float64, requires_grad=self.requires_grad), dim=0)  # output shape: (1, 3, N+96, N+96, N+96)
            else:
                self.x = torch.unsqueeze(torch.tensor(np.asarray(dis_in_padded), dtype=torch.float32, requires_grad=self.requires_grad), dim=0)  # output shape: (1, 3, N+96, N+96, N+96)
            if self.cuda_avail:
                self.x = self.x.cuda()
            if self.debug:
                myprint("Step 6 of forward pass took %s seconds" % (time.time() - start_time))

            # Step 7 - Pipe through emulator
            if self.debug:
                start_time = time.time()
            if self.requires_grad:
                self.y = self.NN(self.x, self.Om)  # output shape: (1, 3, N, N, N)
            else:
                with torch.no_grad():
                    self.y = self.NN(self.x, self.Om)  # output shape: (1, 3, N, N, N)
            if self.debug:
                myprint("Step 7 of forward pass took %s seconds" % (time.time() - start_time))

            # Step 8 - N-body sim displacement
            if self.debug:
                start_time = time.time()
            dis_out = torch.squeeze(self.y).detach().cpu().numpy()

        else:
            myprint('Skipping padding and NN.')
            dis_out = dis_in
        if self.debug:
            myprint("Step 8 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 9 - undo the normalization
        if self.debug:
            start_time = time.time()
        dis(dis_out, undo=True)
        if self.debug:
            self.dis_out = dis_out
            myprint("Step 9 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 10 - convert displacement into positions
        if self.debug:
            start_time = time.time()
        pos = dis_out + q_reshaped
        if self.debug:
            myprint("Step 10 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 11 - make sure everything is within the box
        if self.debug:
            start_time = time.time()
        pos[pos >= self.box.L[0]] -= self.box.L[0]
        pos[pos < 0] += self.box.L[0]
        if self.debug:
            myprint("Step 11 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 12 - reshape positions
        if self.debug:
            start_time = time.time()
        self.pos_out = pos.reshape(3, self.box.N[0]**3, order='C').T  # output shape: (N^3, 3)
        if self.debug:
            myprint("Step 12 of forward pass took %s seconds" % (time.time() - start_time))

        # Step 13 - CIC
        if self.debug:
            start_time = time.time()
        #self.dens_out = cic_analytical(self.pos_out[:, 0], self.pos_out[:, 1], self.pos_out[:, 2], *self.box.N + self.box.L)
        self.dens_out, _ = vjp(self.jax_cic, self.pos_out[:, 0], self.pos_out[:, 1], self.pos_out[:, 2])

        if self.debug:
            myprint("Step 13 of forward pass took %s seconds" % (time.time() - start_time))
            myprint('Finished forward pass of emu in %s seconds' % (time.time() - init_time))

    def getParticlePositions(self, output_array):
        output_array[:] = self.pos_out

    def getDensityFinal_impl(self, output_array):
        output_array[:] = self.dens_out

    # Adjoint part
    def adjointModel_v2_impl(self, input_ag):

        init_time = time.time()
        # reverse step 13 (CIC)
        if self.debug:
            start_time = time.time()
        # TODO: N and L are hard-coded here, as in jax_cic below
        ag = np.asarray(self.cic_analytical_grad(input_ag, self.pos_out[:,0], self.pos_out[:,1], self.pos_out[:,2], 128, 128, 128, 250, 250, 250))
        ag = np.copy(ag)
        if self.debug:
            myprint("Reverse step 13 took %s seconds" % (time.time() - start_time))

        # reverse step 11
        if self.debug:
            start_time = time.time()
        ag = np.reshape(ag.T, (3, *self.box.N), order='C')  # changed recently
        if self.debug:
            myprint("Reverse step 11 took %s seconds" % (time.time() - start_time))

        # reverse step 9
        if self.debug:
            start_time = time.time()
        dis(ag, undo=False)
        if self.debug:
            myprint("Reverse step 9 took %s seconds" % (time.time() - start_time))

        if self.use_pad_and_NN:
            # reverse step 8
            if self.debug:
                myprint('adjoint... Using padding and NN.')
                start_time = time.time()
            if self.use_float64:
                ag = torch.unsqueeze(torch.tensor(ag, dtype=torch.float64), dim=0)
            else:
                ag = torch.unsqueeze(torch.tensor(ag, dtype=torch.float32), dim=0)
            if self.cuda_avail:
                ag = ag.cuda()
            if self.debug:
                myprint("Reverse step 8 took %s seconds" % (time.time() - start_time))

            # reverse step 7
            if self.debug:
                start_time = time.time()
            ag = torch.autograd.grad(self.y, self.x, grad_outputs=ag, retain_graph=False)[0]
            if self.debug:
                myprint("Reverse step 7 took %s seconds" % (time.time() - start_time))

            # reverse step 6
            if self.debug:
                start_time = time.time()
            ag = torch.squeeze(ag).detach().cpu().numpy()
            #ag = torch.squeeze(ag).detach().numpy()
            if self.debug:
                myprint("Reverse step 6 took %s seconds" % (time.time() - start_time))

            # reverse step 5
            if self.debug:
                start_time = time.time()

            ag = np.asarray(self.ag_pad(ag))[0]  # not sure why adjoint outputs shape (1,3,128,128,128)
            if self.debug:
                myprint("Reverse step 5 took %s seconds" % (time.time() - start_time))

        else:
            myprint('adjoint... Skipping padding and NN.')

        # reverse step 4
        if self.debug:
            start_time = time.time()
        dis(ag, undo=True)
        if self.debug:
            myprint("Reverse step 4 took %s seconds" % (time.time() - start_time))

        # reverse step 3
        if self.debug:
            start_time = time.time()
        self.ag_pos = ag.reshape(3, self.box.N[0]**3, order='C').T
        if self.debug:
            myprint("Reverse step 3 took %s seconds" % (time.time() - start_time))

        # memory
        if self.debug:
            start_time = time.time()
        self.ag_pos = np.copy(self.ag_pos, order='C')
        if self.debug:
            myprint("Fixing memory issue took %s seconds" % (time.time() - start_time))
            myprint('Finished backward pass of emu in %s seconds' % (time.time() - init_time))

    def getAdjointModel_impl(self, output_ag):
        # Set adjoint gradient wrt density field to zero
        output_ag[:] = 0
        # and instead specify the adjoint gradient wrt particle positions
        self.prev_module.adjointModelParticles(ag_pos=np.array(self.ag_pos, dtype=self.dtype), ag_vel=np.zeros_like(self.ag_pos, dtype=self.dtype))

    @partial(jax.jit, static_argnums=(0,))
    def padding(self, x):
        return jnp.pad(x, ((0,0), (48,48), (48,48), (48,48)), 'wrap')

    @partial(jax.jit, static_argnums=(0,))
    def jax_cic(self, x, y, z):
        # TODO: fix hard-coded N and L
        Nx = Ny = Nz = 128
        Lx = Ly = Lz = 250.

        Ntot = Nx * Ny * Nz
        x = x * Nx / Lx
        y = y * Ny / Ly
        z = z * Nz / Lz

        # NOTE: get_cell_coord is not defined in this commit
        qx, ix = get_cell_coord(x)
        qy, iy = get_cell_coord(y)
        qz, iz = get_cell_coord(z)

        ix = ix.astype(int)
        iy = iy.astype(int)
        iz = iz.astype(int)
        rx = 1.0 - qx
        ry = 1.0 - qy
        rz = 1.0 - qz
        jx = (ix + 1) % Nx
        jy = (iy + 1) % Ny
        jz = (iz + 1) % Nz

        rho = jnp.zeros((Ntot,))

        for a in [False, True]:
            for b in [False, True]:
                for c in [False, True]:
                    ax = jx if a else ix
                    ay = jy if b else iy
                    az = jz if c else iz
                    ux = qx if a else rx
                    uy = qy if b else ry
                    uz = qz if c else rz

                    idx = az + Nz * ay + Nz * Ny * ax
                    rho += jnp.bincount(idx, weights=ux * uy * uz, length=Ntot)

        return rho.reshape((Nx, Ny, Nz)) / (x.shape[0] / Ntot) - 1.0

    @partial(jax.jit, static_argnums=(0,5,6,7,8,9,10))
    def cic_analytical_grad(self, density, x, y, z, Nx, Ny, Nz, Lx, Ly, Lz):
        Ntot = Nx * Ny * Nz
        x = x * Nx / Lx
        y = y * Ny / Ly
        z = z * Nz / Lz

        qx, ix = get_cell_coord(x)
        qy, iy = get_cell_coord(y)
        qz, iz = get_cell_coord(z)

        ix = ix.astype(int)
        iy = iy.astype(int)
        iz = iz.astype(int)

        rx = 1.0 - qx
        ry = 1.0 - qy
        rz = 1.0 - qz
        jx = (ix + 1) % Nx
        jy = (iy + 1) % Ny
        jz = (iz + 1) % Nz

        adj_gradient = jnp.zeros((Nx**3, 3))

        for coord in np.arange(3):
            if coord == 0:
                rx = 1
                qx = -1
                ry = y - iy
                qy = 1 - ry
                rz = z - iz
                qz = 1 - rz

                adj_gradient = adj_gradient.at[:,0].set(density[ix,iy,iz] * qx * qy * qz + \
                    density[ix,iy,jz] * qx * qy * rz + \
                    density[ix,jy,iz] * qx * ry * qz + \
                    density[ix,jy,jz] * qx * ry * rz + \
                    density[jx,iy,iz] * rx * qy * qz + \
                    density[jx,iy,jz] * rx * qy * rz + \
                    density[jx,jy,iz] * rx * ry * qz + \
                    density[jx,jy,jz] * rx * ry * rz)

            elif coord == 1:
                rx = x - ix
                qx = 1 - rx
                ry = 1
                qy = -1
                rz = z - iz
                qz = 1 - rz

                adj_gradient = adj_gradient.at[:,1].set(density[ix,iy,iz] * qx * qy * qz + \
                    density[ix,iy,jz] * qx * qy * rz + \
                    density[ix,jy,iz] * qx * ry * qz + \
                    density[ix,jy,jz] * qx * ry * rz + \
                    density[jx,iy,iz] * rx * qy * qz + \
                    density[jx,iy,jz] * rx * qy * rz + \
                    density[jx,jy,iz] * rx * ry * qz + \
                    density[jx,jy,jz] * rx * ry * rz)
            else:
                rx = x - ix
                qx = 1 - rx
                ry = y - iy
                qy = 1 - ry
                rz = 1
                qz = -1

                adj_gradient = adj_gradient.at[:,2].set(density[ix,iy,iz] * qx * qy * qz +
                    density[ix,iy,jz] * qx * qy * rz +
                    density[ix,jy,iz] * qx * ry * qz +
                    density[ix,jy,jz] * qx * ry * rz +
                    density[jx,iy,iz] * rx * qy * qz +
                    density[jx,iy,jz] * rx * qy * rz +
                    density[jx,jy,iz] * rx * ry * qz +
                    density[jx,jy,jz] * rx * ry * rz)

        adj_gradient *= Nx / Lx

        return adj_gradient

class NullForward(borg.forward.BaseForwardModel):
    """
    BORG forward model which does nothing but stores
    the values of parameters to be used by the likelihood
    """
    def __init__(self, box: borg.forward.BoxModel) -> None:
        """
        Initialise the NullForward class

        Args:
            box (borg.forward.BoxModel): The input box model.
        """
        super().__init__(box, box)
        self.setName("nullforward")

        self.params = {}
        self.setCosmoParams(borg.cosmo.CosmologicalParameters())
        cosmo = self.getCosmoParams()
        cosmo.n_s = 0.96241
        self.setCosmoParams(cosmo)

    def setModelParams(self, params: dict) -> None:
        """
        Change the values of the model parameters to those given by params

        Args:
            params (dict): Dictionary of updated model parameters.
        """
        for k, v in params.items():
            self.params[k] = v
        print(" ")
        myprint(f'Updated model parameters: {self.params}')

    def getModelParam(self, model, keyname: str):
        """
        Query the current value of the parameter keyname in the given model.

        Args:
            model: The model to query
            keyname (str): The name of the parameter of interest
        """
        return self.params[keyname]
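The padding step in Emulator leans on jax.vjp: the same call that pads the displacement field also returns the pullback that maps a gradient on the padded field back to the unpadded one (used in "reverse step 5"). A standalone sketch of that pattern, with a small pad width of 4 instead of the model's 48 so it runs in a moment:

# Pad-then-adjoint pattern used by Emulator.padding / Emulator.ag_pad.
import jax.numpy as jnp
from jax import vjp

def padding(x):  # the real model pads by 48; 4 keeps this demo tiny
    return jnp.pad(x, ((0, 0), (4, 4), (4, 4), (4, 4)), 'wrap')

x = jnp.arange(3 * 8**3, dtype=jnp.float32).reshape(3, 8, 8, 8)
y, pullback = vjp(padding, x)
print(y.shape)               # (3, 16, 16, 16)
g = pullback(jnp.ones_like(y))[0]
print(g.shape)               # (3, 8, 8, 8)
print(float(g[0, 0, 0, 0]))  # 8.0: with 'wrap', each voxel has 2x2x2 periodic copies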
516 mgborg_emulator/likelihoods.py Normal file
@@ -0,0 +1,516 @@
import numpy as np
import jax.numpy as jnp
import configparser
import warnings
import aquila_borg as borg
import symbolic_pofk.linear
import jax
from functools import partial
import ast
import torch  # needed by build_gravity_model below

import utils as utils
from utils import myprint
import forwards
import networks

class GaussianLikelihood(borg.likelihood.BaseLikelihood):
    """
    HADES Gaussian likelihood
    """

    def __init__(self, fwd: borg.forward.BaseForwardModel, param_model: forwards.NullForward, ini_file: str) -> None:
        """
        Initialises the GaussianLikelihood class

        Args:
            - fwd (borg.forward.BaseForwardModel): The forward model to be used in the likelihood.
            - param_model (forwards.NullForward): An empty forward model for storing model parameters.
            - ini_file (str): The name of the ini file containing the model and borg parameters.
        """

        self.ini_file = ini_file

        myprint("Reading from configuration file: " + ini_file)
        config = configparser.ConfigParser()
        config.read(ini_file)

        # Grid parameters
        self.N = [int(config['system'][f'N{i}']) for i in range(3)]
        self.L = [float(config['system'][f'L{i}']) for i in range(3)]

        # What type of run we're doing
        self.run_type = config['run']['run_type']

        # Seed if creating mocks
        self.mock_seed = int(config['mock']['seed'])

        # For log-likelihood values
        self.bignum = float(config['mcmc']['bignum'])

        myprint(f" Init {self.N}, {self.L}")
        super().__init__(fwd, self.N, self.L)

        # Define the forward models
        self.fwd = fwd
        self.fwd_param = param_model

        # Initialise model parameters
        # NOTE: gauss_sigma must be set in the [model] section of the ini file;
        # conf/basic_ini.ini does not define it yet
        model_params = {
            'logfR0': float(config['model']['logfr0']),
            'gauss_sigma': float(config['model']['gauss_sigma'])
        }
        self.fwd_param.setModelParams(model_params)

        # Initialise cosmological parameters
        cpar = utils.get_cosmopar(self.ini_file)
        self.fwd.setCosmoParams(cpar)
        self.fwd_param.setCosmoParams(cpar)
        self.updateCosmology(cpar)
        myprint(f"Original cosmological parameters: {self.fwd.getCosmoParams()}")

        # Initialise derivative
        self.grad_like = jax.grad(self.dens2like)

    def initializeLikelihood(self, state: borg.likelihood.MarkovState) -> None:
        """
        Initialise the likelihood internal variables and MarkovState variables.

        Args:
            - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        """

        myprint("Init likelihood")
        state.newArray3d("mock", *self.fwd.getOutputBoxModel().N, False)
        state.newArray3d("BORG_final_density", *self.fwd.getOutputBoxModel().N, True)

    def updateMetaParameters(self, state: borg.likelihood.MarkovState) -> None:
        """
        Update the meta parameters of the sampler (not sampled) from the MarkovState.

        Args:
            - state (borg.likelihood.MarkovState): The state object to be used in the likelihood.
        """
        cpar = state['cosmology']
        cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
        self.fwd.setCosmoParams(cpar)
        self.fwd_param.setCosmoParams(cpar)

    def updateCosmology(self, cosmo: borg.cosmo.CosmologicalParameters) -> None:
        """
        Updates the forward model's cosmological parameters with the given values.

        Args:
            - cosmo (borg.cosmo.CosmologicalParameters): The cosmological parameters.
        """
        cpar = cosmo

        # Convert sigma8 to As
        cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
            cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
        myprint(f"Updating cosmology Om = {cosmo.omega_m}, sig8 = {cosmo.sigma8}, As = {cosmo.A_s}")

        cpar.omega_q = 1. - cpar.omega_m - cpar.omega_k
        self.fwd.setCosmoParams(cpar)
        self.fwd_param.setCosmoParams(cpar)

    def generateMockData(self, s_hat: np.ndarray, state: borg.likelihood.MarkovState) -> None:
        """
        Generates mock data by simulating the forward model with the given white noise
        and adding Gaussian noise. Not yet implemented for any run type.

        Args:
            - s_hat (np.ndarray): The input (initial) density field.
            - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        """
        if self.run_type == 'data':
            raise NotImplementedError
        elif self.run_type == 'mock':
            raise NotImplementedError
        else:
            raise NotImplementedError

    def dens2like(self, output_density: np.ndarray):
        """
        Given stored data, computes the negative log-likelihood of the data
        for this final density field.

        Args:
            output_density (np.ndarray): The z=0 density field.
        Return:
            lkl (float): The negative log-likelihood of the data.
        """

        # NOTE: self.data must be set (e.g. by generateMockData) before this is called
        sigma = self.fwd_param.getModelParam('nullforward', 'gauss_sigma')
        lkl = (self.data - output_density)**2 / (2. * sigma ** 2) + 0.5 * jnp.log(2 * jnp.pi * sigma**2)
        lkl = jnp.sum(lkl)

        if not jnp.isfinite(lkl):
            lkl = self.bignum

        return lkl

    def logLikelihoodComplex(self, s_hat: np.ndarray, gradientIsNext: bool):
        """
        Calculates the negative log-likelihood of the data.

        Args:
            - s_hat (np.ndarray): The input white noise.
            - gradientIsNext (bool): If True, prepares the forward model for gradient calculations.

        Returns:
            The negative log-likelihood value.
        """

        N = self.fwd.getBoxModel().N[0]
        L = self.fwd.getOutputBoxModel().L[0]

        # Run BORG density field
        output_density = np.zeros((N, N, N))
        self.fwd.forwardModel_v2(s_hat)
        self.fwd.getDensityFinal(output_density)

        self.delta = output_density

        L = self.dens2like(output_density)
        myprint(f"var(s_hat): {np.var(s_hat)}, Call to logLike: {L}")

        return L

    def gradientLikelihoodComplex(self, s_hat: np.ndarray):
        """
        Calculates the adjoint negative log-likelihood of the data.

        Args:
            - s_hat (np.ndarray): The input white noise.

        Returns:
            The adjoint negative log-likelihood gradient.
        """

        N = self.fwd.getBoxModel().N[0]
        L = self.fwd.getOutputBoxModel().L[0]

        # Run BORG density field
        output_density = np.zeros((N, N, N))
        self.fwd.forwardModel_v2(s_hat)
        self.fwd.getDensityFinal(output_density)

        mygradient = self.grad_like(output_density)
        mygradient = np.array(mygradient, dtype=np.float64)

        self.fwd.adjointModel_v2(mygradient)
        mygrad_hat = np.zeros(s_hat.shape, dtype=np.complex128)
        self.fwd.getAdjointModel(mygrad_hat)
        self.fwd.clearAdjointGradient()

        return mygrad_hat

    def commitAuxiliaryFields(self, state: borg.likelihood.MarkovState) -> None:
        """
        Commits the final density field to the Markov state.

        Args:
            - state (borg.state.State): The state object containing the final density field.
        """
        self.updateCosmology(self.fwd.getCosmoParams())
        self.dens2like(self.delta)
        state["BORG_final_density"][:] = self.delta


@borg.registerGravityBuilder
def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.BoxModel, ini_file=None) -> borg.forward.BaseForwardModel:
    """
    Builds the gravity model and returns the forward model chain.

    Args:
        - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        - box (borg.forward.BoxModel): The input box model.
        - ini_file (str, default=None): The location of the ini file. If None, use borg.getIniConfigurationFilename()

    Returns:
        borg.forward.BaseForwardModel: The forward model.
    """

    global chain, fwd_param
    myprint("Building gravity model")

    if ini_file is None:
        myprint("Reading from configuration file: " + borg.getIniConfigurationFilename())
        config = configparser.ConfigParser()
        config.read(borg.getIniConfigurationFilename())
    else:
        myprint("Reading from configuration file: " + ini_file)
        config = configparser.ConfigParser()
        config.read(ini_file)
    ai = float(config['model']['ai'])
    af = float(config['model']['af'])

    # Cosmological parameters
    if ini_file is None:
        cpar = utils.get_cosmopar(borg.getIniConfigurationFilename())
    else:
        cpar = utils.get_cosmopar(ini_file)

    # Setup forward model
    chain = borg.forward.ChainForwardModel(box)
    chain.setCosmoParams(cpar)  # set after the chain is created
    chain.addModel(borg.forward.models.HermiticEnforcer(box))

    # CLASS transfer function
    chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
    transfer_class = borg.forward.model_lib.M_TRANSFER_CLASS(box, opts=dict(a_transfer=1.0))
    transfer_class.setModelParams({"extra_class_arguments": {'YHe': '0.24'}})
    chain @= transfer_class

    if config['model']['gravity'] == 'lpt':
        lpt = borg.forward.model_lib.M_LPT_CIC(
            box,
            opts=dict(a_initial=af,
                      a_final=af,
                      do_rsd=False,
                      supersampling=1,
                      lightcone=False,
                      part_factor=1.01,))
    elif config['model']['gravity'] == '2lpt':
        lpt = borg.forward.model_lib.M_2LPT_CIC(
            box,
            opts=dict(a_initial=af,
                      a_final=af,
                      do_rsd=False,
                      supersampling=1,
                      lightcone=False,
                      part_factor=1.01,))
    else:
        raise NotImplementedError(config['model']['gravity'])

    lpt.accumulateAdjoint(True)
    chain @= lpt

    if config['emulator']['use_emulator'].lower().strip() == 'true':
        myprint('Adding emulator to the chain')

        # Check device:
        cuda_avail = torch.cuda.is_available()
        device = torch.device('cuda' if cuda_avail else 'cpu')
        myprint(f'device = {device}')
        if cuda_avail:
            torch.cuda.empty_cache()

        # Load weights
        model_weights_path = config['emulator']['model_weights_path']
        myprint(f'Use model params from: {model_weights_path}')
        emu_weights = torch.load(model_weights_path, map_location=device)

        # Initialize model (instantiate the class before loading the state dict;
        # (1, 3, 3) = (style_size, in_chan, out_chan) as in the commented defaults)
        model = getattr(networks, config['emulator']['architecture'])(1, 3, 3)
        # if use_distilled:
        #     model = networks.StyledVNet_distill(1,3,3,num_filt=num_filt)
        # else:
        #     model = networks.StyledVNet(1,3,3)
        model.load_state_dict(emu_weights['model'])

        use_float64 = config['emulator']['use_float64'].lower().strip() == 'true'
        if use_float64:
            model.double()
            dtype = torch.float64
        else:
            dtype = torch.float32
        model.to(device)

        # Extract omega as style param
        Om = cpar.omega_m
        Om = torch.tensor([Om], dtype=dtype)  # style parameter

        # from emulator hacking:
        Om -= torch.tensor([0.3])
        Om *= torch.tensor([5.0])
        if cuda_avail:
            Om = Om.cuda()

        # Initial positions
        # NOTE: initial_pos is not defined in this commit
        q = initial_pos(box.L[0], box.N[0])

        # Create module in BORG chain
        emu = forwards.Emulator(
            box,
            lpt,
            model,
            Om,
            config['emulator']['use_pad_and_NN'].strip().lower() == 'true',
            use_float64,
            q,
            config['emulator']['requires_grad'].strip().lower() == 'true',
            cuda_avail,
            debug=False
        )
        chain.addModel(emu)

        # DIFFERENT TO LUDVIG - CHECK THIS (note: addModel plus @= adds emu twice)
        chain @= emu

    # This is the forward model for the model parameters
    fwd_param = borg.forward.ChainForwardModel(box)
    mod_null = forwards.NullForward(box)
    fwd_param.addModel(mod_null)
    fwd_param.setCosmoParams(cpar)

    return chain

_glob_model = None
_glob_cosmo = None
begin_model = None
begin_cosmo = None

def check_model_sampling(loop):
    return loop.getStepID() > begin_model

def check_cosmo_sampling(loop):
    return loop.getStepID() > begin_cosmo


@borg.registerSamplerBuilder
def build_sampler(
        state: borg.likelihood.MarkovState,
        info: borg.likelihood.LikelihoodInfo,
        loop: borg.samplers.MainLoop
):
    """
    Builds the sampler and returns it.
    Which parameters to sample are given in the ini file.
    We assume a parameter is NOT meant to be sampled unless we find "XX_sampler_blocked = false" in the ini file.

    Args:
        - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        - info (borg.likelihood.LikelihoodInfo): The likelihood information.
        - loop (borg.samplers.MainLoop): The main sampling loop.

    Returns:
        List of samplers to use.
    """
    global _glob_model, _glob_cosmo, begin_model, begin_cosmo
    borg.print_msg(borg.Level.std, "Hello sampler, loop is {l}, step_id={s}", l=loop, s=loop.getStepID())

    myprint("Building sampler")

    myprint("Reading from configuration file: " + borg.getIniConfigurationFilename())
    config = configparser.ConfigParser()
    config.read(borg.getIniConfigurationFilename())
    end = '_sampler_blocked'
    to_sample = [k[:-len(end)] for (k, v) in config['block_loop'].items() if k[-len(end):] == end and v.lower() == 'false']
    myprint(f'Parameters to sample: {to_sample}')

    all_sampler = []

    # Cosmology sampler arguments
    prefix = ""
    params = []
    initial_values = {}
    prior = {}
    for p in ["omega_m", "sigma8"]:
        if p not in to_sample:
            continue
        if p in config['prior'].keys() and p in config['cosmology'].keys():
            myprint(f'Adding {p} sampler')
            params.append(f"cosmology.{p}")
            initial_values[f"cosmology.{p}"] = float(config['cosmology'][p])
            prior[f"cosmology.{p}"] = np.array(ast.literal_eval(config['prior'][p]))
        else:
            s = f'Could not find {p} prior and/or default, so will not sample'
            warnings.warn(s, stacklevel=2)
        # Remove for later to prevent duplication
        to_sample.remove(p)

    begin_cosmo = int(config['mcmc']['warmup_cosmo'])

    if len(params) > 0:
        myprint('Adding cosmological parameter sampler')
        cosmo_sampler = borg.samplers.ModelParamsSampler(prefix, params, likelihood, chain, initial_values, prior)
        cosmo_sampler.setName("cosmo_sampler")
        _glob_cosmo = cosmo_sampler
        all_sampler.append(cosmo_sampler)
        loop.push(cosmo_sampler)
        loop.addToConditionGroup("warmup_cosmo", "cosmo_sampler")
        loop.addConditionToConditionGroup("warmup_cosmo", partial(check_cosmo_sampling, loop))

    # Model parameter sampler
    prefix = ""
    params = []
    initial_values = {}
    prior = {}
    for p in to_sample:
        if p in config['prior'].keys():
            myprint(f'Adding {p} sampler')
            params.append(p)
            initial_values[p] = float(config['model'][p])
            if 'inf' in config['prior'][p]:
                x = ast.literal_eval(config['prior'][p].replace('inf', '"inf"'))
                prior[p] = np.array([xx if xx != 'inf' else np.inf for xx in x])
            else:
                prior[p] = np.array(ast.literal_eval(config['prior'][p]))
        else:
            s = f'Could not find {p} prior, so will not sample'
            warnings.warn(s, stacklevel=2)

    begin_model = int(config['mcmc']['warmup_model'])

    if len(params) > 0:
        myprint('Adding model parameter sampler')
        model_sampler = borg.samplers.ModelParamsSampler(prefix, params, likelihood, fwd_param, initial_values, prior)
        model_sampler.setName("model_sampler")
        _glob_model = model_sampler
        loop.push(model_sampler)
        all_sampler.append(model_sampler)
        loop.addToConditionGroup("warmup_model", "model_sampler")
        loop.addConditionToConditionGroup("warmup_model", partial(check_model_sampling, loop))

    print('Warmups:', begin_cosmo, begin_model)

    return []


@borg.registerLikelihoodBuilder
def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.LikelihoodInfo) -> borg.likelihood.BaseLikelihood:
    """
    Builds the likelihood object and returns it.

    Args:
        - state (borg.likelihood.MarkovState): The Markov state object to be used in the likelihood.
        - info (borg.likelihood.LikelihoodInfo): The likelihood information.

    Returns:
        borg.likelihood.BaseLikelihood: The likelihood object.
    """
    global likelihood, fwd_param
    myprint("Building likelihood")
    myprint(chain.getCosmoParams())
    likelihood = GaussianLikelihood(chain, fwd_param, borg.getIniConfigurationFilename())
    return likelihood


"""
NOTES:
- Currently if we update Om, then we don't seem to update the emulator
- A few arguments are missing from forwards.Emulator
"""
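The core of the likelihood is the pairing of dens2like with jax.grad, so that the same code gives both the negative log-likelihood and the gradient fed into the adjoint model. A standalone sketch of that pattern on toy 8^3 fields (sigma = 0.5 is an arbitrary demo value):

# Pixel-wise Gaussian negative log-likelihood and its gradient via jax.grad,
# mirroring self.grad_like = jax.grad(self.dens2like) above.
import jax
import jax.numpy as jnp
import numpy as np

data = jnp.asarray(np.random.default_rng(123).normal(size=(8, 8, 8)))
sigma = 0.5

def dens2like(output_density):
    lkl = (data - output_density)**2 / (2. * sigma**2) + 0.5 * jnp.log(2 * jnp.pi * sigma**2)
    return jnp.sum(lkl)

grad_like = jax.grad(dens2like)
delta = jnp.zeros((8, 8, 8))
print(dens2like(delta))        # scalar negative log-likelihood
print(grad_like(delta).shape)  # (8, 8, 8): the array handed to adjointModel_v2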
42 mgborg_emulator/map2map_tools/narrow.py Normal file
@@ -0,0 +1,42 @@
import torch
import torch.nn as nn


def narrow_by(a, c):
    """Narrow a by size c symmetrically on all edges."""
    ind = (slice(None),) * 2 + (slice(c, -c),) * (a.dim() - 2)
    return a[ind]


def narrow_cast(*tensors):
    """Narrow each tensor to the minimum length in each dimension.

    Try to be symmetric but cut more on the right for odd difference
    """
    dim_max = max(a.dim() for a in tensors)

    len_min = {d: min(a.shape[d] for a in tensors) for d in range(2, dim_max)}

    casted_tensors = []
    for a in tensors:
        for d in range(2, dim_max):
            width = a.shape[d] - len_min[d]
            half_width = width // 2
            a = a.narrow(d, half_width, a.shape[d] - width)

        casted_tensors.append(a)

    return casted_tensors


def narrow_like(a, b):
    """Narrow a to be like b.

    Try to be symmetric but cut more on the right for odd difference
    """
    for d in range(2, a.dim()):
        width = a.shape[d] - b.shape[d]
        half_width = width // 2
        a = a.narrow(d, half_width, a.shape[d] - width)
    return a
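A usage sketch for these helpers: the first two dims (batch, channel) are kept and only spatial dims are cropped, which is how StyledVNet trims its skip connections to the valid region of unpadded convolutions:

import torch
from map2map_tools.narrow import narrow_by, narrow_like

a = torch.zeros(1, 3, 20, 20, 20)
print(narrow_by(a, 4).shape)    # torch.Size([1, 3, 12, 12, 12])

b = torch.zeros(1, 3, 13, 13, 13)
print(narrow_like(a, b).shape)  # torch.Size([1, 3, 13, 13, 13]); odd cut takes one more on the right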
179 mgborg_emulator/map2map_tools/style.py Normal file
@@ -0,0 +1,179 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PixelNorm(nn.Module):
    """Pixelwise normalization after conv layers.

    See ProGAN, StyleGAN.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, eps=1e-8):
        return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)


class LinearElr(nn.Module):
    """Linear layer with equalized learning rate.

    See ProGAN, StyleGAN, and 1706.05350

    Useful at all if not for regularization (1706.05350)?
    """
    def __init__(self, in_size, out_size, bias=True, act=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_size, in_size))
        self.wnorm = 1 / math.sqrt(in_size)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_size))
        else:
            self.register_parameter('bias', None)

        self.act = act

    def forward(self, x):
        x = F.linear(x, self.weight * self.wnorm, bias=self.bias)

        if self.act:
            x = F.leaky_relu(x, negative_slope=0.2)

        return x


class ConvElr3d(nn.Module):
    """Conv3d layer with equalized learning rate.

    See ProGAN, StyleGAN, and 1706.05350

    Useful at all if not for regularization (1706.05350)?
    """
    def __init__(self, in_chan, out_chan, kernel_size,
                 stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_chan, in_chan, *(kernel_size,) * 3),
        )
        fan_in = in_chan * kernel_size ** 3
        self.wnorm = 1 / math.sqrt(fan_in)

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_chan))
        else:
            self.register_parameter('bias', None)

        self.stride = stride
        self.padding = padding

    def forward(self, x):
        x = F.conv3d(  # was F.conv2d, which cannot take the 3d kernels above
            x,
            self.weight * self.wnorm,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return x


class ConvStyled3d(nn.Module):
    """Convolution layer with modulation and demodulation, from StyleGAN2.

    Weight and bias initialization from `torch.nn._ConvNd.reset_parameters()`.
    """
    def __init__(self, style_size, in_chan, out_chan, kernel_size=3, stride=1,
                 bias=True, resample=None):
        super().__init__()

        self.style_weight = nn.Parameter(torch.empty(in_chan, style_size))
        nn.init.kaiming_uniform_(self.style_weight, a=math.sqrt(5),
                                 mode='fan_in', nonlinearity='leaky_relu')
        self.style_bias = nn.Parameter(torch.ones(in_chan))  # NOTE: init to 1

        if resample is None:
            K3 = (kernel_size,) * 3
            self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
            self.stride = stride
            self.conv = F.conv3d
        elif resample == 'U':
            K3 = (2,) * 3
            # NOTE not clear to me why convtranspose have channels swapped
            self.weight = nn.Parameter(torch.empty(in_chan, out_chan, *K3))
            self.stride = 2
            self.conv = F.conv_transpose3d
        elif resample == 'D':
            K3 = (2,) * 3
            self.weight = nn.Parameter(torch.empty(out_chan, in_chan, *K3))
            self.stride = 2
            self.conv = F.conv3d
        else:
            raise ValueError('resample type {} not supported'.format(resample))
        self.resample = resample

        nn.init.kaiming_uniform_(
            self.weight, a=math.sqrt(5),
            mode='fan_in',  # effectively 'fan_out' for 'D'
            nonlinearity='leaky_relu',
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_chan))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def forward(self, x, s, eps=1e-8):
        N, Cin, *DHWin = x.shape
        C0, C1, *K3 = self.weight.shape
        if self.resample == 'U':
            Cin, Cout = C0, C1
        else:
            Cout, Cin = C0, C1

        s = F.linear(s, self.style_weight, bias=self.style_bias)

        # modulation
        if self.resample == 'U':
            s = s.reshape(N, Cin, 1, 1, 1, 1)
        else:
            s = s.reshape(N, 1, Cin, 1, 1, 1)
        w = self.weight * s

        # demodulation
        if self.resample == 'U':
            fan_in_dim = (1, 3, 4, 5)
        else:
            fan_in_dim = (2, 3, 4, 5)
        w = w * torch.rsqrt(w.pow(2).sum(dim=fan_in_dim, keepdim=True) + eps)

        w = w.reshape(N * C0, C1, *K3)
        x = x.reshape(1, N * Cin, *DHWin)
        x = self.conv(x, w, bias=self.bias, stride=self.stride, groups=N)
        _, _, *DHWout = x.shape
        x = x.reshape(N, Cout, *DHWout)

        return x


class BatchNormStyled3d(nn.BatchNorm3d):
    """Trivially does standard batch normalization, but accepts a second
    argument for the style array that is not used
    """
    def forward(self, x, s):
        return super().forward(x)


class LeakyReLUStyled(nn.LeakyReLU):
    """Trivially evaluates standard leaky ReLU, but accepts a second
    argument for the style array that is not used
    """
    def forward(self, x, s):
        return super().forward(x)
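A usage sketch for ConvStyled3d: the style vector s modulates the kernel per sample, and the batch is folded into the channel axis with groups=N so each sample is convolved with its own modulated weights (batch size 1 matches the emulator's use):

import torch
from map2map_tools.style import ConvStyled3d

conv = ConvStyled3d(style_size=1, in_chan=3, out_chan=8, kernel_size=3)
x = torch.randn(1, 3, 16, 16, 16)
s = torch.randn(1, 1)     # one style scalar per sample (e.g. the rescaled Omega_m)
y = conv(x, s)
print(y.shape)            # torch.Size([1, 8, 14, 14, 14]); no padding, so -2 per dim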
141 mgborg_emulator/map2map_tools/styled_conv.py Normal file
@@ -0,0 +1,141 @@
import warnings
import torch
import torch.nn as nn

from map2map_tools.narrow import narrow_like
from map2map_tools.style import ConvStyled3d, BatchNormStyled3d, LeakyReLUStyled


class ConvStyledBlock(nn.Module):
    """Convolution blocks of the form specified by `seq`.

    `seq` types:
        'C': convolution specified by `kernel_size` and `stride`
        'B': normalization (to be renamed to 'N')
        'A': activation
        'U': upsampling transposed convolution of kernel size 2 and stride 2
        'D': downsampling convolution of kernel size 2 and stride 2
    """
    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
                 kernel_size=3, stride=1, seq='CBA'):
        super().__init__()

        if out_chan is None:
            out_chan = in_chan

        self.style_size = style_size
        self.in_chan = in_chan
        self.out_chan = out_chan
        if mid_chan is None:
            self.mid_chan = max(in_chan, out_chan)
        else:
            self.mid_chan = mid_chan  # was never set when mid_chan was given
        self.kernel_size = kernel_size
        self.stride = stride

        self.norm_chan = in_chan
        self.idx_conv = 0
        self.num_conv = sum([seq.count(l) for l in ['U', 'D', 'C']])

        layers = [self._get_layer(l) for l in seq]

        self.convs = nn.ModuleList(layers)

    def _get_layer(self, l):
        if l == 'U':
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
                                resample='U')
        elif l == 'D':
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(self.style_size, in_chan, out_chan, 2, stride=2,
                                resample='D')
        elif l == 'C':
            in_chan, out_chan = self._setup_conv()
            return ConvStyled3d(self.style_size, in_chan, out_chan, self.kernel_size,
                                stride=self.stride)
        elif l == 'B':
            return BatchNormStyled3d(self.norm_chan)
        elif l == 'A':
            return LeakyReLUStyled()
        else:
            raise ValueError('layer type {} not supported'.format(l))

    def _setup_conv(self):
        self.idx_conv += 1

        in_chan = out_chan = self.mid_chan
        if self.idx_conv == 1:
            in_chan = self.in_chan
        if self.idx_conv == self.num_conv:
            out_chan = self.out_chan

        self.norm_chan = out_chan

        return in_chan, out_chan

    def forward(self, x, s):
        for l in self.convs:
            x = l(x, s)
        return x


class ResStyledBlock(ConvStyledBlock):
    """Residual convolution blocks of the form specified by `seq`.
    Input, via a skip connection, is added to the residual followed by an
    optional activation.

    The skip connection is identity if `out_chan` is omitted, otherwise it uses
    a size 1 "convolution", i.e. one can trigger the latter by setting
    `out_chan` even if it equals `in_chan`.

    A trailing `'A'` in seq can either operate before or after the addition,
    depending on the boolean value of `last_act`, defaulting to `seq[-1] == 'A'`

    See `ConvStyledBlock` for `seq` types.
    """
    def __init__(self, style_size, in_chan, out_chan=None, mid_chan=None,
                 kernel_size=3, stride=1, seq='CBACBA', last_act=None):
        if last_act is None:
            last_act = seq[-1] == 'A'
        elif last_act and seq[-1] != 'A':
            warnings.warn(
                'Disabling last_act without trailing activation in seq',
                RuntimeWarning,
            )
            last_act = False

        if last_act:
            seq = seq[:-1]

        super().__init__(style_size, in_chan, out_chan=out_chan, mid_chan=mid_chan,
                         kernel_size=kernel_size, stride=stride, seq=seq)

        if last_act:
            self.act = LeakyReLUStyled()
        else:
            self.act = None

        if out_chan is None:
            self.skip = None
        else:
            self.skip = ConvStyled3d(style_size, in_chan, out_chan, 1)

        if 'U' in seq or 'D' in seq:
            raise NotImplementedError('upsample and downsample layers '
                                      'not supported yet')

    def forward(self, x, s):
        y = x

        if self.skip is not None:
            y = self.skip(y, s)

        for l in self.convs:
            x = l(x, s)

        y = narrow_like(y, x)
        x += y

        if self.act is not None:
            x = self.act(x, s)

        return x
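A usage sketch for ResStyledBlock: a 'CACA' block applies two valid (unpadded) 3x3x3 styled convolutions, so the skip connection is narrowed to match before the addition:

import torch
from map2map_tools.styled_conv import ResStyledBlock

block = ResStyledBlock(style_size=1, in_chan=3, out_chan=8, seq='CACA')
x = torch.randn(1, 3, 16, 16, 16)
s = torch.randn(1, 1)
print(block(x, s).shape)  # torch.Size([1, 8, 12, 12, 12]); two 3^3 convs cost 4 voxels per dim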
159 mgborg_emulator/networks.py Normal file
@@ -0,0 +1,159 @@
import torch
import torch.nn as nn

from map2map_tools.styled_conv import ConvStyledBlock, ResStyledBlock
from map2map_tools.narrow import narrow_by


class StyledVNet(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
        """V-Net like network with styles

        See `vnet.VNet`.
        """
        super().__init__()

        # activate non-identity skip connection in residual block
        # by explicitly setting out_chan
        self.conv_l00 = ResStyledBlock(style_size, in_chan, 64, seq='CACA')
        self.conv_l01 = ResStyledBlock(style_size, 64, 64, seq='CACA')
        self.down_l0 = ConvStyledBlock(style_size, 64, seq='DA')
        self.conv_l1 = ResStyledBlock(style_size, 64, 64, seq='CACA')
        self.down_l1 = ConvStyledBlock(style_size, 64, seq='DA')
        self.conv_l2 = ResStyledBlock(style_size, 64, 64, seq='CACA')
        self.down_l2 = ConvStyledBlock(style_size, 64, seq='DA')

        self.conv_c = ResStyledBlock(style_size, 64, 64, seq='CACA')

        self.up_r2 = ConvStyledBlock(style_size, 64, seq='UA')
        self.conv_r2 = ResStyledBlock(style_size, 128, 64, seq='CACA')
        self.up_r1 = ConvStyledBlock(style_size, 64, seq='UA')
        self.conv_r1 = ResStyledBlock(style_size, 128, 64, seq='CACA')
        self.up_r0 = ConvStyledBlock(style_size, 64, seq='UA')
        self.conv_r00 = ResStyledBlock(style_size, 128, 64, seq='CACA')
        self.conv_r01 = ResStyledBlock(style_size, 64, out_chan, seq='CAC')

        if bypass is None:
            self.bypass = in_chan == out_chan
        else:
            self.bypass = bypass

    def forward(self, x, s):
        if self.bypass:
            x0 = x

        x = self.conv_l00(x, s)
        y0 = self.conv_l01(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        y2 = self.conv_l2(x, s)
        x = self.down_l2(y2, s)

        x = self.conv_c(x, s)

        x = self.up_r2(x, s)
        y2 = narrow_by(y2, 4)
        x = torch.cat([y2, x], dim=1)
        del y2
        x = self.conv_r2(x, s)

        x = self.up_r1(x, s)
        y1 = narrow_by(y1, 16)
        x = torch.cat([y1, x], dim=1)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        y0 = narrow_by(y0, 40)
        x = torch.cat([y0, x], dim=1)
        del y0
        x = self.conv_r00(x, s)
        x = self.conv_r01(x, s)

        if self.bypass:
            x0 = narrow_by(x0, 48)
            x += x0

        return x


class StyledVNet_distill(nn.Module):
    def __init__(self, style_size, in_chan, out_chan, num_filt=32, alpha=None, bypass=None, **kwargs):
        """V-Net like network with styles

        See `vnet.VNet`.
        """
        super().__init__()

        print(f'Number of filters = {num_filt}')
        if alpha is None:  # default changed from a mutable [1,1,1,1] to None
            alpha = [1, 1, 1, 1]
        print(f'alpha = {alpha}')

        # activate non-identity skip connection in residual block
        # by explicitly setting out_chan
        self.conv_l00 = ResStyledBlock(style_size, in_chan, num_filt*alpha[0], seq='CACA')
        self.conv_l01 = ResStyledBlock(style_size, num_filt*alpha[0], num_filt*alpha[1], seq='CACA')
        self.down_l0 = ConvStyledBlock(style_size, num_filt*alpha[1], seq='DA')
        self.conv_l1 = ResStyledBlock(style_size, num_filt*alpha[1], num_filt*alpha[2], seq='CACA')
        self.down_l1 = ConvStyledBlock(style_size, num_filt*alpha[2], seq='DA')
        self.conv_l2 = ResStyledBlock(style_size, num_filt*alpha[2], num_filt*alpha[3], seq='CACA')
        self.down_l2 = ConvStyledBlock(style_size, num_filt*alpha[3], seq='DA')

        self.conv_c = ResStyledBlock(style_size, num_filt*alpha[3], num_filt*alpha[3], seq='CACA')

        self.up_r2 = ConvStyledBlock(style_size, num_filt*alpha[3], seq='UA')
        self.conv_r2 = ResStyledBlock(style_size, num_filt*2*alpha[3], num_filt*alpha[3], seq='CACA')
        self.up_r1 = ConvStyledBlock(style_size, num_filt*alpha[2], seq='UA')
        self.conv_r1 = ResStyledBlock(style_size, num_filt*2*alpha[2], num_filt*alpha[2], seq='CACA')
        self.up_r0 = ConvStyledBlock(style_size, num_filt*alpha[1], seq='UA')
        self.conv_r00 = ResStyledBlock(style_size, num_filt*2*alpha[1], num_filt*alpha[1], seq='CACA')
        self.conv_r01 = ResStyledBlock(style_size, num_filt*alpha[0], out_chan, seq='CAC')

        if bypass is None:
            self.bypass = in_chan == out_chan
        else:
            self.bypass = bypass

    def forward(self, x, s):
        if self.bypass:
            x0 = x

        x = self.conv_l00(x, s)
        y0 = self.conv_l01(x, s)
        x = self.down_l0(y0, s)

        y1 = self.conv_l1(x, s)
        x = self.down_l1(y1, s)

        y2 = self.conv_l2(x, s)
        x = self.down_l2(y2, s)

        x = self.conv_c(x, s)

        x = self.up_r2(x, s)
        y2 = narrow_by(y2, 4)
        x = torch.cat([y2, x], dim=1)
        del y2
        x = self.conv_r2(x, s)

        x = self.up_r1(x, s)
        y1 = narrow_by(y1, 16)
        x = torch.cat([y1, x], dim=1)
        del y1
        x = self.conv_r1(x, s)

        x = self.up_r0(x, s)
        y0 = narrow_by(y0, 40)
        x = torch.cat([y0, x], dim=1)
        del y0
        x = self.conv_r00(x, s)
        x = self.conv_r01(x, s)

        if self.bypass:
            x0 = narrow_by(x0, 48)
            x += x0

        return x
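The 48-voxel wrap padding in forwards.py and the narrow_by crops here are two halves of the same bookkeeping: every conv is unpadded ("valid"), so for the emulator's N = 128 the (1, 3, 224, 224, 224) padded input shrinks back to exactly the unpadded box. A shape walkthrough plus a small smoke test (104 is a convenient small size that survives all the crops; note 64-channel activations make even this a few GB of RAM — the shapes are the point):

#   224 -(l00,l01)-> 216 -D-> 108 -(l1)-> 104 -D-> 52 -(l2)-> 48 -D-> 24
#    24 -(conv_c)-> 20 -U-> 40   cat with narrow_by(y2, 4):  48 -> 40
#    -(r2)-> 36 -U-> 72          cat with narrow_by(y1, 16): 104 -> 72
#    -(r1)-> 68 -U-> 136         cat with narrow_by(y0, 40): 216 -> 136
#    -(r00,r01)-> 128            bypass: narrow_by(x0, 48):  224 -> 128
import torch
from networks import StyledVNet

model = StyledVNet(1, 3, 3)  # style_size, in_chan, out_chan
x = torch.randn(1, 3, 104, 104, 104)
s = torch.randn(1, 1)
with torch.no_grad():
    print(model(x, s).shape)  # torch.Size([1, 3, 8, 8, 8]) = 104 - 96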
36 mgborg_emulator/utils.py Normal file
@@ -0,0 +1,36 @@
import aquila_borg as borg
import configparser
import symbolic_pofk.linear  # needed by get_cosmopar below

cons = borg.console()
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))


def get_cosmopar(ini_file):
    """
    Extract cosmological parameters from an ini file

    Args:
        :ini_file (str): Path to the ini file

    Returns:
        :cpar (borg.cosmo.CosmologicalParameters): Cosmological parameters
    """

    config = configparser.ConfigParser()
    config.read(ini_file)

    cpar = borg.cosmo.CosmologicalParameters()
    cpar.default()
    cpar.fnl = float(config['cosmology']['fnl'])
    cpar.omega_k = float(config['cosmology']['omega_k'])
    cpar.omega_m = float(config['cosmology']['omega_m'])
    cpar.omega_b = float(config['cosmology']['omega_b'])
    cpar.omega_q = float(config['cosmology']['omega_q'])
    cpar.h = float(config['cosmology']['h100'])
    cpar.sigma8 = float(config['cosmology']['sigma8'])
    cpar.n_s = float(config['cosmology']['n_s'])
    cpar.w = float(config['cosmology']['w'])
    cpar.wprime = float(config['cosmology']['wprime'])
    cpar.A_s = 1.e-9 * symbolic_pofk.linear.sigma8_to_As(
        cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)

    return cpar
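A usage sketch (assuming aquila_borg and symbolic_pofk are installed): get_cosmopar ties the [cosmology] section of conf/basic_ini.ini to the sigma8 -> A_s conversion, so one ini file configures both BORG and the emulator's style parameter.

from utils import get_cosmopar

cpar = get_cosmopar('conf/basic_ini.ini')
print(cpar.omega_m, cpar.sigma8, cpar.A_s)  # 0.315, 0.81, and the derived A_s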