ICOD/models/trainer.py

159 lines
5.0 KiB
Python
Raw Normal View History

2024-06-10 11:49:02 +02:00
import numpy as np
import _pickle as cPickle
import traceback
#--------------------
class storage_layer:
    """Running mean and covariance of 8-component samples at every position of a grid.

    ``mean`` has shape ``inputshape + (8,)`` and ``var`` has shape
    ``inputshape + (8, 8)``; both hold population (divide-by-count)
    statistics of all samples passed to ``update_datum`` so far.
    """

    def __init__(self, inputshape=(1, 1, 1)):
        """Create empty statistics for samples of shape ``inputshape + (8,)``.

        Args:
            inputshape: tuple giving the grid shape in front of the component axis.
        """
        self.inputshape = inputshape
        self.outputshape = inputshape
        self.mean = np.zeros(inputshape + (8,))
        self.var = np.zeros(inputshape + (8, 8))
        self.count = 0

    def update_datum(self, z):
        """Fold one sample ``z`` (last axis = 8 components) into the running statistics.

        Uses Welford's online update: with ``d = z - mean_old`` and ``n`` the
        previous count, ``cov_new = n * (cov_old + d d^T / (n + 1)) / (n + 1)``.
        The deviation must be taken *before* the mean is updated; using the
        updated mean underestimates the covariance.

        Args:
            z: array broadcastable to ``inputshape + (8,)``.

        Returns:
            0 (kept for backward compatibility).
        """
        n = self.count
        delta = z - self.mean  # deviation from the mean *before* this sample
        self.mean = (n * self.mean + z) / (n + 1)
        # Per-position outer product d d^T; a*b == b*a in floating point, so
        # the result stays exactly symmetric.
        outer = delta[..., :, None] * delta[..., None, :]
        # Update in place so existing references to self.var stay valid.
        self.var[...] = n * (self.var + outer / (n + 1)) / (n + 1)
        self.count += 1
        return 0
#--------------------
class base_storage_layer:
    """Running element-wise mean and variance of samples of shape ``inputshape``.

    ``mean`` and ``var`` both have shape ``inputshape`` and hold population
    (divide-by-count) statistics of all samples passed to ``update_datum``.
    """

    def __init__(self, inputshape=(1, 1, 1)):
        """Create empty statistics for samples of shape ``inputshape``."""
        self.inputshape = inputshape
        self.outputshape = inputshape
        self.mean = np.zeros(inputshape)
        self.var = np.zeros(inputshape)
        self.count = 0

    def update_datum(self, z):
        """Fold one sample ``z`` into the running mean and variance.

        Welford's online update: with ``d = z - mean_old`` and ``n`` the
        previous count, ``var_new = n * (var_old + d**2 / (n + 1)) / (n + 1)``.
        The deviation is taken *before* the mean update; using the updated
        mean would bias the variance low.

        Returns:
            0 (kept for backward compatibility).
        """
        n = self.count
        delta = z - self.mean  # deviation from the mean *before* this sample
        self.mean = (n * self.mean + z) / (n + 1)
        self.var = n * (self.var + delta * delta / (n + 1)) / (n + 1)
        self.count += 1
        return 0
#--------------------
class trainer:
    """Accumulate per-level sample statistics and derive generative-model parameters from them.

    The trainer keeps one statistics layer per level of ``gen_model`` (built by
    ``construct``). ``train_single`` passes a sample backwards through the
    model's kernels and records the resulting components. ``transfer`` turns
    the accumulated means/covariances into the model parameters ``Minv``,
    ``a``, ``b`` and ``w``.

    The instance state (``__dict__``) is persisted with pickle to
    ``<name>.txt``, where ``name`` is taken from the assignment that created
    the trainer.
    """

    def __init__(self, gen_model):
        """Build statistics layers for ``gen_model``, then restore or create the save file.

        Args:
            gen_model: model to train; must expose ``nlevel`` and a ``model``
                list, plus ``get_matrix_inv``/``get_matrix_sqrt`` for ``transfer``.
        """
        self.gen_model = gen_model
        self.nlevel = gen_model.nlevel
        self.model = self.construct(nlevel=self.nlevel)
        # Name the trainer after the variable it is assigned to
        # (``tr = trainer(gm)`` -> "tr"), see
        # http://stackoverflow.com/questions/1690400/getting-an-instance-name-inside-class-init
        # extract_stack()[-1] is this frame, [-2] is the caller's line. The
        # source text may be None when no source is available; treat it as
        # empty instead of crashing.
        caller_text = traceback.extract_stack()[-2].line or ''
        self.name = caller_text[:caller_text.find('=')].strip()
        print('Trainer model:', self.name)
        try:
            print('load initial state')
            self.load()
        except Exception:
            # No usable save file (missing, unreadable, not a valid pickle):
            # start from scratch and write the initial state.
            print('save initial state')
            self.save()
        else:
            # load() replaces the whole __dict__, including a pickled *copy*
            # of gen_model. Re-attach the caller's model so that transfer()
            # updates the object the caller actually holds.
            self.gen_model = gen_model

    def construct(self, nlevel=2):
        """Return the statistics layers: one base layer followed by ``nlevel`` storage layers.

        The base layer and the first storage layer have shape (1, 1, 1); each
        further storage layer doubles every dimension of the previous one.
        """
        trainer_model = [base_storage_layer((1, 1, 1))]
        inputshape = (1, 1, 1)
        for _ in np.arange(nlevel):
            trainer_model.append(storage_layer(inputshape))
            inputshape = tuple(2 * l for l in inputshape)
        return trainer_model

    def transfer(self, silent=False):
        """Overwrite the generative model's parameters from the accumulated statistics.

        For each non-base level, with ``P = get_matrix_inv(var)`` of the 8x8
        component covariance:
        ``Minv = P[:7, :7]``, ``R = get_matrix_inv(Minv)``,
        ``a_n = -sum_l P[7, l] * R[l, n]``, ``w = get_matrix_sqrt(R)`` and
        ``b_n = mean_n - a_n * mean_7`` (n = 0..6). Assuming
        ``get_matrix_inv`` is an ordinary inverse, ``a_n`` equals
        ``cov[n, 7] / cov[7, 7]``. Level 0 receives ``b = mean`` and
        ``w = sqrt(var)``.

        Calling this repeatedly with unchanged statistics yields the same
        parameters.

        Returns:
            0
        """
        if not silent:
            print('Train model...')
        # Walk the levels from the last one down to level 1; the base layer
        # (index 0) is handled separately below.
        for m, t in zip(self.gen_model.model[::-1][:-1], self.model[::-1][:-1]):
            M_z_inv = self.gen_model.get_matrix_inv(t.var)
            R_inv = M_z_inv[..., 0:7, 0:7]
            m.Minv = np.copy(R_inv)
            R = self.gen_model.get_matrix_inv(R_inv)
            # Assign rather than accumulate, so repeated transfers do not make
            # `a` grow.
            m.a[..., 0:7] = -np.einsum('...l,...ln->...n', M_z_inv[..., 7, 0:7], R)
            m.w = self.gen_model.get_matrix_sqrt(R)
            m.b[..., 0:7] = t.mean[..., 0:7] - m.a[..., 0:7] * t.mean[..., 7:8]
        # Base layer: copy its mean and the square root of its variance.
        self.gen_model.model[0].b = self.model[0].mean
        self.gen_model.model[0].w = np.sqrt(self.model[0].var)
        if not silent:
            print('Training done')
        return 0

    def train_single(self, x_train, silent=False):
        """Record the statistics of one training sample at every level.

        Starting from the last level, ``x_train`` is passed through that
        level's kernel with ``direction='bwd'``. The result ``z`` is
        accumulated in the matching storage layer, and ``z[:, :, :, 7]``
        becomes the input of the next level down. The final remainder is
        accumulated in the base layer.
        """
        assert isinstance(self.gen_model.model, list), "gen_model must be a list"
        assert isinstance(self.model, list), "trainer_model must be a list"
        if not silent:
            print('Train model...')
        # Work through the reversed model lists, skipping the base layer.
        for m, t in zip(self.gen_model.model[::-1][:-1], self.model[::-1][:-1]):
            z = m.kernel.apply(x_train, direction='bwd')
            t.update_datum(z)
            x_train = z[:, :, :, 7]
        self.model[0].update_datum(x_train)
        if not silent:
            print('Training done')

    def save(self):
        """Pickle the instance state to ``<name>.txt``.

        The state is serialized before the file is opened, so a pickling
        failure leaves any existing save file intact.
        """
        data = cPickle.dumps(self.__dict__)
        with open(self.name + '.txt', 'wb') as fh:
            fh.write(data)

    def load(self):
        """Replace the instance state with the one pickled in ``<name>.txt``.

        Raises OSError if the file cannot be read, or an unpickling error if
        its content is invalid.
        NOTE(review): unpickling executes arbitrary code; only load save files
        you created yourself.
        """
        with open(self.name + '.txt', 'rb') as fh:
            data = fh.read()
        self.__dict__ = cPickle.loads(data)