Some minor fixes
This commit is contained in:
parent
42be9de326
commit
2016ef0599
11 changed files with 460 additions and 31 deletions
|
@ -4,7 +4,8 @@ import numpy as np
|
|||
|
||||
# Emulator imports
|
||||
import torch
|
||||
# from .model_architecture.cosmology import *
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
from networks import StyledVNet, StyledVNet_distill
|
||||
|
||||
# JAX set-up
|
||||
|
@ -16,6 +17,7 @@ from jax.config import config as jax_config
|
|||
jax_config.update("jax_enable_x64", True)
|
||||
|
||||
from utils import myprint
|
||||
import utils
|
||||
|
||||
class Emulator(borg.forward.BaseForwardModel):
|
||||
|
||||
|
@ -85,7 +87,7 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
# Step 2 - correct for particles that moved over the periodic boundary
|
||||
if self.debug:
|
||||
start_time = time.time()
|
||||
disp_temp = correct_displacement_over_periodic_boundaries(disp,L=self.box.L[0],max_disp_1d=self.box.L[0]//2)
|
||||
disp_temp = utils.correct_displacement_over_periodic_boundaries(disp,L=self.box.L[0],max_disp_1d=self.box.L[0]//2)
|
||||
if self.debug:
|
||||
myprint("Step 2 of forward pass took %s seconds" % (time.time() - start_time))
|
||||
|
||||
|
@ -102,7 +104,7 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
# Step 4 - normalize
|
||||
if self.debug:
|
||||
start_time = time.time()
|
||||
dis(dis_in)
|
||||
utils.dis(dis_in)
|
||||
if self.debug:
|
||||
myprint("Step 4 of forward pass took %s seconds" % (time.time() - start_time))
|
||||
|
||||
|
@ -162,7 +164,7 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
# Step 9 - undo the normalization
|
||||
if self.debug:
|
||||
start_time = time.time()
|
||||
dis(dis_out,undo=True)
|
||||
utils.dis(dis_out,undo=True)
|
||||
if self.debug:
|
||||
self.dis_out = dis_out
|
||||
myprint("Step 9 of forward pass took %s seconds" % (time.time() - start_time))
|
||||
|
@ -227,7 +229,7 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
# reverse step 9
|
||||
if self.debug:
|
||||
start_time = time.time()
|
||||
dis(ag,undo=False)
|
||||
utils.dis(ag,undo=False)
|
||||
if self.debug:
|
||||
myprint("Reverse step 9 took %s seconds" % (time.time() - start_time))
|
||||
|
||||
|
@ -274,7 +276,7 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
# reverse step 4
|
||||
if self.debug:
|
||||
start_time = time.time()
|
||||
dis(ag,undo=True)
|
||||
utils.dis(ag,undo=True)
|
||||
if self.debug:
|
||||
myprint("Reverse step 4 took %s seconds" % (time.time() - start_time))
|
||||
|
||||
|
@ -314,9 +316,9 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
y = y * Ny / Ly
|
||||
z = z * Nz / Lz
|
||||
|
||||
qx, ix = get_cell_coord(x)
|
||||
qy, iy = get_cell_coord(y)
|
||||
qz, iz = get_cell_coord(z)
|
||||
qx, ix = utils.get_cell_coord(x)
|
||||
qy, iy = utils.get_cell_coord(y)
|
||||
qz, iz = utils.get_cell_coord(z)
|
||||
|
||||
ix = ix.astype(int)
|
||||
iy = iy.astype(int)
|
||||
|
@ -352,9 +354,9 @@ class Emulator(borg.forward.BaseForwardModel):
|
|||
y = y * Ny / Ly
|
||||
z = z * Nz / Lz
|
||||
|
||||
qx, ix = get_cell_coord(x)
|
||||
qy, iy = get_cell_coord(y)
|
||||
qz, iz = get_cell_coord(z)
|
||||
qx, ix = utils.get_cell_coord(x)
|
||||
qy, iy = utils.get_cell_coord(y)
|
||||
qz, iz = utils.get_cell_coord(z)
|
||||
|
||||
ix = ix.astype(int)
|
||||
iy = iy.astype(int)
|
||||
|
|
|
@ -7,9 +7,10 @@ import symbolic_pofk.linear
|
|||
import jax
|
||||
from functools import partial
|
||||
import ast
|
||||
import torch
|
||||
|
||||
import utils as utils
|
||||
from utils import myprint
|
||||
from utils import myprint, initial_pos
|
||||
import forwards
|
||||
import networks
|
||||
|
||||
|
@ -56,7 +57,7 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
|
|||
|
||||
# Initialise model parameters
|
||||
model_params = {
|
||||
'logfR0':float(config['model']['logfr0'])
|
||||
'logfR0':float(config['model']['logfr0']),
|
||||
'gauss_sigma':float(config['model']['gauss_sigma'])
|
||||
}
|
||||
self.fwd_param.setModelParams(model_params)
|
||||
|
@ -134,7 +135,12 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
|
|||
if self.run_type == 'data':
|
||||
raise NotImplementedError
|
||||
elif self.run_type == 'mock':
|
||||
raise NotImplementedError
|
||||
output_density = np.zeros(self.fwd.getOutputBoxModel().N)
|
||||
self.fwd.forwardModel_v2(s_hat)
|
||||
self.fwd.getDensityFinal(output_density)
|
||||
state["BORG_final_density"][:] = output_density
|
||||
output_density += np.random.normal(size=self.fwd.getOutputBoxModel().N) * self.getModelParam('nullforward', 'gauss_sigma')
|
||||
state["mock"][:] = output_density
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -179,9 +185,6 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
|
|||
output_density = np.zeros((N,N,N))
|
||||
self.fwd.forwardModel_v2(s_hat)
|
||||
self.fwd.getDensityFinal(output_density)
|
||||
|
||||
# Get velocity field
|
||||
output_velocity = self.fwd_vel.getVelocityField()
|
||||
|
||||
self.delta = output_density
|
||||
self.vel = output_velocity
|
||||
|
@ -219,7 +222,7 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
|
|||
self.fwd.adjointModel_v2(mygradient)
|
||||
mygrad_hat = np.zeros(s_hat.shape, dtype=np.complex128)
|
||||
self.fwd.getAdjointModel(mygrad_hat)
|
||||
elf.fwd.clearAdjointGradient()
|
||||
self.fwd.clearAdjointGradient()
|
||||
|
||||
return mygrad_hat
|
||||
|
||||
|
@ -263,6 +266,10 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
|||
config.read(ini_file)
|
||||
ai = float(config['model']['ai'])
|
||||
af = float(config['model']['af'])
|
||||
|
||||
# Setup forward model
|
||||
chain = borg.forward.ChainForwardModel(box)
|
||||
chain.addModel(borg.forward.models.HermiticEnforcer(box))
|
||||
|
||||
# Cosmological parameters
|
||||
if ini_file is None:
|
||||
|
@ -270,10 +277,6 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
|||
else:
|
||||
cpar = utils.get_cosmopar(ini_file)
|
||||
chain.setCosmoParams(cpar)
|
||||
|
||||
# Setup forward model
|
||||
chain = borg.forward.ChainForwardModel(box)
|
||||
chain.addModel(borg.forward.models.HermiticEnforcer(box))
|
||||
|
||||
# CLASS transfer function
|
||||
chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
|
||||
|
@ -322,10 +325,7 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
|
|||
|
||||
# Initialize model
|
||||
model = getattr(networks, config['emulator']['architecture'])
|
||||
# if use_distilled:
|
||||
# model = networks.StyledVNet_distill(1,3,3,num_filt=num_filt)
|
||||
# else:
|
||||
# model = networks.StyledVNet(1,3,3)
|
||||
model = model(int(config['emulator']['style_size']), int(config['emulator']['in_chan']), int(config['emulator']['out_chan']))
|
||||
model.load_state_dict(emu_weights['model'])
|
||||
|
||||
use_float64 = config['emulator']['use_float64'].lower().strip() == 'true'
|
||||
|
@ -505,7 +505,7 @@ def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.L
|
|||
myprint("Building likelihood")
|
||||
myprint(chain.getCosmoParams())
|
||||
boxm = chain.getBoxModel()
|
||||
likelihood = VelocityBORGLikelihood(chain, fwd_param, fwd_vel, borg.getIniConfigurationFilename())
|
||||
likelihood = VelocityBORGLikelihood(chain, fwd_param, borg.getIniConfigurationFilename())
|
||||
return likelihood
|
||||
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .styled_conv import ConvStyledBlock, ResStyledBlock
|
||||
from .narrow import narrow_by
|
||||
from map2map_tools.styled_conv import ConvStyledBlock, ResStyledBlock
|
||||
from map2map_tools.narrow import narrow_by
|
||||
|
||||
class StyledVNet(nn.Module):
|
||||
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import aquila_borg as borg
|
||||
import configparser
|
||||
import numpy as np
|
||||
import jax.numpy as jnp
|
||||
from scipy.special import hyp2f1
|
||||
import symbolic_pofk.linear
|
||||
|
||||
cons = borg.console()
|
||||
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))
|
||||
|
@ -34,3 +38,83 @@ def get_cosmopar(ini_file):
|
|||
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
|
||||
|
||||
return cpar
|
||||
|
||||
def initial_pos(L,N,order="F"):
|
||||
values = np.linspace(0,L,N+1)[:-1] #ensure LLC
|
||||
xx,yy,zz = np.meshgrid(values,values,values)
|
||||
|
||||
if order=="F":
|
||||
pos_mesh = np.vstack((yy.flatten(),xx.flatten(),zz.flatten())).T
|
||||
if order=="C":
|
||||
pos_mesh = np.vstack((zz.flatten(),xx.flatten(),yy.flatten())).T
|
||||
|
||||
return pos_mesh
|
||||
|
||||
|
||||
def correct_displacement_over_periodic_boundaries(disp,L,max_disp_1d=125):
|
||||
# Need to correct for positions moving over the periodic boundary
|
||||
|
||||
moved_over_bound = L - max_disp_1d
|
||||
axis = ['x','y','z']
|
||||
|
||||
for i in [0,1,2]:
|
||||
idx_sup, idx_sub = check(disp,L,moved_over_bound,max_disp_1d,i,axis)
|
||||
|
||||
# Correct positions
|
||||
disp[:,i][idx_sup] -= L
|
||||
disp[:,i][idx_sub] += L
|
||||
|
||||
check(disp,L,moved_over_bound,max_disp_1d,i,axis)
|
||||
|
||||
assert np.amin(disp[:,i]) >= -max_disp_1d and np.amax(disp[:,i]) <= max_disp_1d, "Particles outside allowed region"
|
||||
|
||||
return disp
|
||||
|
||||
|
||||
def check(disp,L,moved_over_bound,max_disp_1d,i,axis):
|
||||
idxsup = disp[:,i]>moved_over_bound
|
||||
idx = np.abs(disp[:,i])<=max_disp_1d
|
||||
idxsub = disp[:,i]<-moved_over_bound
|
||||
|
||||
sup = len(disp[:,i][idxsup])
|
||||
did_not_cross_boundary = len(disp[:,i][idx])
|
||||
sub = len(disp[:,i][idxsub])
|
||||
|
||||
if not sub+did_not_cross_boundary+sup == len(disp[:,i]):
|
||||
myprint(f'Disp in {axis[i]} direction under -{moved_over_bound} Mpc/h is = '+str(sub))
|
||||
myprint(f'|Disp| in {axis[i]} direction under {max_disp_1d} Mpc/h is = '+str(did_not_cross_boundary))
|
||||
myprint(f'Disp in {axis[i]} direction over {moved_over_bound} Mpc/h is = '+str(sup))
|
||||
myprint('These add up to: '+str(sub+did_not_cross_boundary+sup))
|
||||
myprint(f"Should add up to: len(disp[:,i]) {len(disp[:,i])}")
|
||||
myprint('\n')
|
||||
|
||||
assert sub+did_not_cross_boundary+sup == len(disp[:,i]), "Incorrect summation" # cannot lose/gain particles
|
||||
|
||||
return idxsup, idxsub
|
||||
|
||||
|
||||
def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
|
||||
dis_norm = dis_std * linear_D(z) # [Mpc/h]
|
||||
|
||||
if not undo:
|
||||
dis_norm = 1 / dis_norm
|
||||
|
||||
x *= dis_norm
|
||||
|
||||
|
||||
def linear_D(z, Om=0.31):
|
||||
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
|
||||
"""
|
||||
OL = 1 - Om
|
||||
a = 1 / (1+z)
|
||||
return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \
|
||||
/ hyp2f1(1, 1/3, 11/6, - OL / Om)
|
||||
|
||||
|
||||
def get_cell_coord(x):
|
||||
ix = jnp.floor(x)
|
||||
return x - ix, ix
|
||||
|
||||
def get_cell_coord_np(x):
|
||||
ix = np.floor(x)
|
||||
return x - ix, ix
|
Loading…
Add table
Add a link
Reference in a new issue