Some minor fixes

This commit is contained in:
Deaglan Bartlett 2024-06-14 16:04:42 +02:00
parent 42be9de326
commit 2016ef0599
11 changed files with 460 additions and 31 deletions

View file

@ -4,7 +4,8 @@ import numpy as np
# Emulator imports
import torch
# from .model_architecture.cosmology import *
torch.cuda.empty_cache()
from networks import StyledVNet, StyledVNet_distill
# JAX set-up
@ -16,6 +17,7 @@ from jax.config import config as jax_config
jax_config.update("jax_enable_x64", True)
from utils import myprint
import utils
class Emulator(borg.forward.BaseForwardModel):
@ -85,7 +87,7 @@ class Emulator(borg.forward.BaseForwardModel):
# Step 2 - correct for particles that moved over the periodic boundary
if self.debug:
start_time = time.time()
disp_temp = correct_displacement_over_periodic_boundaries(disp,L=self.box.L[0],max_disp_1d=self.box.L[0]//2)
disp_temp = utils.correct_displacement_over_periodic_boundaries(disp,L=self.box.L[0],max_disp_1d=self.box.L[0]//2)
if self.debug:
myprint("Step 2 of forward pass took %s seconds" % (time.time() - start_time))
@ -102,7 +104,7 @@ class Emulator(borg.forward.BaseForwardModel):
# Step 4 - normalize
if self.debug:
start_time = time.time()
dis(dis_in)
utils.dis(dis_in)
if self.debug:
myprint("Step 4 of forward pass took %s seconds" % (time.time() - start_time))
@ -162,7 +164,7 @@ class Emulator(borg.forward.BaseForwardModel):
# Step 9 - undo the normalization
if self.debug:
start_time = time.time()
dis(dis_out,undo=True)
utils.dis(dis_out,undo=True)
if self.debug:
self.dis_out = dis_out
myprint("Step 9 of forward pass took %s seconds" % (time.time() - start_time))
@ -227,7 +229,7 @@ class Emulator(borg.forward.BaseForwardModel):
# reverse step 9
if self.debug:
start_time = time.time()
dis(ag,undo=False)
utils.dis(ag,undo=False)
if self.debug:
myprint("Reverse step 9 took %s seconds" % (time.time() - start_time))
@ -274,7 +276,7 @@ class Emulator(borg.forward.BaseForwardModel):
# reverse step 4
if self.debug:
start_time = time.time()
dis(ag,undo=True)
utils.dis(ag,undo=True)
if self.debug:
myprint("Reverse step 4 took %s seconds" % (time.time() - start_time))
@ -314,9 +316,9 @@ class Emulator(borg.forward.BaseForwardModel):
y = y * Ny / Ly
z = z * Nz / Lz
qx, ix = get_cell_coord(x)
qy, iy = get_cell_coord(y)
qz, iz = get_cell_coord(z)
qx, ix = utils.get_cell_coord(x)
qy, iy = utils.get_cell_coord(y)
qz, iz = utils.get_cell_coord(z)
ix = ix.astype(int)
iy = iy.astype(int)
@ -352,9 +354,9 @@ class Emulator(borg.forward.BaseForwardModel):
y = y * Ny / Ly
z = z * Nz / Lz
qx, ix = get_cell_coord(x)
qy, iy = get_cell_coord(y)
qz, iz = get_cell_coord(z)
qx, ix = utils.get_cell_coord(x)
qy, iy = utils.get_cell_coord(y)
qz, iz = utils.get_cell_coord(z)
ix = ix.astype(int)
iy = iy.astype(int)

View file

@ -7,9 +7,10 @@ import symbolic_pofk.linear
import jax
from functools import partial
import ast
import torch
import utils as utils
from utils import myprint
from utils import myprint, initial_pos
import forwards
import networks
@ -56,7 +57,7 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
# Initialise model parameters
model_params = {
'logfR0':float(config['model']['logfr0'])
'logfR0':float(config['model']['logfr0']),
'gauss_sigma':float(config['model']['gauss_sigma'])
}
self.fwd_param.setModelParams(model_params)
@ -134,7 +135,12 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
if self.run_type == 'data':
raise NotImplementedError
elif self.run_type == 'mock':
raise NotImplementedError
output_density = np.zeros(self.fwd.getOutputBoxModel().N)
self.fwd.forwardModel_v2(s_hat)
self.fwd.getDensityFinal(output_density)
state["BORG_final_density"][:] = output_density
output_density += np.random.normal(size=self.fwd.getOutputBoxModel().N) * self.getModelParam('nullforward', 'gauss_sigma')
state["mock"][:] = output_density
else:
raise NotImplementedError
@ -179,9 +185,6 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
output_density = np.zeros((N,N,N))
self.fwd.forwardModel_v2(s_hat)
self.fwd.getDensityFinal(output_density)
# Get velocity field
output_velocity = self.fwd_vel.getVelocityField()
self.delta = output_density
self.vel = output_velocity
@ -219,7 +222,7 @@ class GaussianLikelihood(borg.likelihood.BaseLikelihood):
self.fwd.adjointModel_v2(mygradient)
mygrad_hat = np.zeros(s_hat.shape, dtype=np.complex128)
self.fwd.getAdjointModel(mygrad_hat)
elf.fwd.clearAdjointGradient()
self.fwd.clearAdjointGradient()
return mygrad_hat
@ -263,6 +266,10 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
config.read(ini_file)
ai = float(config['model']['ai'])
af = float(config['model']['af'])
# Setup forward model
chain = borg.forward.ChainForwardModel(box)
chain.addModel(borg.forward.models.HermiticEnforcer(box))
# Cosmological parameters
if ini_file is None:
@ -270,10 +277,6 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
else:
cpar = utils.get_cosmopar(ini_file)
chain.setCosmoParams(cpar)
# Setup forward model
chain = borg.forward.ChainForwardModel(box)
chain.addModel(borg.forward.models.HermiticEnforcer(box))
# CLASS transfer function
chain @= borg.forward.model_lib.M_PRIMORDIAL_AS(box)
@ -322,10 +325,7 @@ def build_gravity_model(state: borg.likelihood.MarkovState, box: borg.forward.Bo
# Initialize model
model = getattr(networks, config['emulator']['architecture'])
# if use_distilled:
# model = networks.StyledVNet_distill(1,3,3,num_filt=num_filt)
# else:
# model = networks.StyledVNet(1,3,3)
model = model(int(config['emulator']['style_size']), int(config['emulator']['in_chan']), int(config['emulator']['out_chan']))
model.load_state_dict(emu_weights['model'])
use_float64 = config['emulator']['use_float64'].lower().strip() == 'true'
@ -505,7 +505,7 @@ def build_likelihood(state: borg.likelihood.MarkovState, info: borg.likelihood.L
myprint("Building likelihood")
myprint(chain.getCosmoParams())
boxm = chain.getBoxModel()
likelihood = VelocityBORGLikelihood(chain, fwd_param, fwd_vel, borg.getIniConfigurationFilename())
likelihood = VelocityBORGLikelihood(chain, fwd_param, borg.getIniConfigurationFilename())
return likelihood

View file

@ -1,8 +1,8 @@
import torch
import torch.nn as nn
from .styled_conv import ConvStyledBlock, ResStyledBlock
from .narrow import narrow_by
from map2map_tools.styled_conv import ConvStyledBlock, ResStyledBlock
from map2map_tools.narrow import narrow_by
class StyledVNet(nn.Module):
def __init__(self, style_size, in_chan, out_chan, bypass=None, **kwargs):

View file

@ -1,5 +1,9 @@
import aquila_borg as borg
import configparser
import numpy as np
import jax.numpy as jnp
from scipy.special import hyp2f1
import symbolic_pofk.linear
cons = borg.console()
myprint = lambda x: cons.print_std(x) if type(x) == str else cons.print_std(repr(x))
@ -34,3 +38,83 @@ def get_cosmopar(ini_file):
cpar.sigma8, cpar.omega_m, cpar.omega_b, cpar.h, cpar.n_s)
return cpar
def initial_pos(L,N,order="F"):
values = np.linspace(0,L,N+1)[:-1] #ensure LLC
xx,yy,zz = np.meshgrid(values,values,values)
if order=="F":
pos_mesh = np.vstack((yy.flatten(),xx.flatten(),zz.flatten())).T
if order=="C":
pos_mesh = np.vstack((zz.flatten(),xx.flatten(),yy.flatten())).T
return pos_mesh
def correct_displacement_over_periodic_boundaries(disp,L,max_disp_1d=125):
# Need to correct for positions moving over the periodic boundary
moved_over_bound = L - max_disp_1d
axis = ['x','y','z']
for i in [0,1,2]:
idx_sup, idx_sub = check(disp,L,moved_over_bound,max_disp_1d,i,axis)
# Correct positions
disp[:,i][idx_sup] -= L
disp[:,i][idx_sub] += L
check(disp,L,moved_over_bound,max_disp_1d,i,axis)
assert np.amin(disp[:,i]) >= -max_disp_1d and np.amax(disp[:,i]) <= max_disp_1d, "Particles outside allowed region"
return disp
def check(disp,L,moved_over_bound,max_disp_1d,i,axis):
idxsup = disp[:,i]>moved_over_bound
idx = np.abs(disp[:,i])<=max_disp_1d
idxsub = disp[:,i]<-moved_over_bound
sup = len(disp[:,i][idxsup])
did_not_cross_boundary = len(disp[:,i][idx])
sub = len(disp[:,i][idxsub])
if not sub+did_not_cross_boundary+sup == len(disp[:,i]):
myprint(f'Disp in {axis[i]} direction under -{moved_over_bound} Mpc/h is = '+str(sub))
myprint(f'|Disp| in {axis[i]} direction under {max_disp_1d} Mpc/h is = '+str(did_not_cross_boundary))
myprint(f'Disp in {axis[i]} direction over {moved_over_bound} Mpc/h is = '+str(sup))
myprint('These add up to: '+str(sub+did_not_cross_boundary+sup))
myprint(f"Should add up to: len(disp[:,i]) {len(disp[:,i])}")
myprint('\n')
assert sub+did_not_cross_boundary+sup == len(disp[:,i]), "Incorrect summation" # cannot lose/gain particles
return idxsup, idxsub
def dis(x, undo=False, z=0.0, dis_std=6.0, **kwargs):
dis_norm = dis_std * linear_D(z) # [Mpc/h]
if not undo:
dis_norm = 1 / dis_norm
x *= dis_norm
def linear_D(z, Om=0.31):
"""linear growth function for flat LambdaCDM, normalized to 1 at redshift zero
"""
OL = 1 - Om
a = 1 / (1+z)
return a * hyp2f1(1, 1/3, 11/6, - OL * a**3 / Om) \
/ hyp2f1(1, 1/3, 11/6, - OL / Om)
def get_cell_coord(x):
ix = jnp.floor(x)
return x - ix, ix
def get_cell_coord_np(x):
ix = np.floor(x)
return x - ix, ix