2024-06-10 11:49:02 +02:00
|
|
|
import numpy as np
|
|
|
|
import _pickle as cPickle
|
|
|
|
import traceback
|
|
|
|
|
|
|
|
#--------------------
|
|
|
|
class storage_layer:
    """Running per-cell mean and covariance of 8-component sample vectors.

    Samples ``z`` of shape ``inputshape + (8,)`` are accumulated one at a time.
    For every cell of the ``inputshape`` grid the layer keeps the running mean
    of the 8 components and their 8x8 population (divide-by-N) covariance.

    Attributes:
        inputshape: grid shape given to the constructor.
        outputshape: the same tuple as ``inputshape``.
        mean: array of shape ``inputshape + (8,)``.
        var: array of shape ``inputshape + (8, 8)`` holding the covariance.
        count: number of samples accumulated so far.
    """

    def __init__(self, inputshape=(1, 1, 1)):
        self.inputshape = inputshape
        self.outputshape = inputshape

        self.mean = np.zeros(inputshape + (8,))
        self.var = np.zeros(inputshape + (8, 8))

        self.count = 0

    def update_datum(self, z):
        """Fold one sample into the running mean and covariance.

        Args:
            z: array broadcastable to ``self.mean``, i.e. of shape
               ``inputshape + (8,)``.

        Returns:
            0 (kept for compatibility with existing callers).
        """
        n = self.count

        # Deviation from the mean *before* this sample is included. Using
        # the deviation from the already-updated mean (d_new * d_new) makes
        # each covariance increment too small by a factor n/(n+1). Welford's
        # update with d = z - mean_old is
        #     C_new = (n*C + n/(n+1) * d d^T) / (n+1),
        # which is unbiased for the population covariance and exactly
        # symmetric.
        delta = z - self.mean
        self.mean = (n * self.mean + z) / (n + 1)

        # Outer product over the trailing component axis, for every grid cell.
        outer = delta[..., :, np.newaxis] * delta[..., np.newaxis, :]
        # Write in place (as the original per-element loop did) so the
        # covariance array object keeps its identity.
        self.var[...] = (n * self.var + (n / (n + 1)) * outer) / (n + 1)

        self.count += 1
        return 0
|
|
|
|
|
|
|
|
#--------------------
|
|
|
|
class base_storage_layer:
    """Running per-cell mean and variance of scalar samples.

    Samples ``z`` of shape ``inputshape`` are accumulated one at a time.
    For every cell the layer keeps the running mean and the population
    (divide-by-N) variance.

    Attributes:
        inputshape: grid shape given to the constructor.
        outputshape: the same tuple as ``inputshape``.
        mean: array of shape ``inputshape``.
        var: array of shape ``inputshape``.
        count: number of samples accumulated so far.
    """

    def __init__(self, inputshape=(1, 1, 1)):
        self.inputshape = inputshape
        self.outputshape = inputshape

        self.mean = np.zeros(inputshape)
        self.var = np.zeros(inputshape)

        self.count = 0

    def update_datum(self, z):
        """Fold one sample into the running mean and variance.

        Args:
            z: array broadcastable to ``self.mean`` (shape ``inputshape``).

        Returns:
            0 (kept for compatibility with existing callers).
        """
        n = self.count

        # Deviation from the mean *before* this sample is included. Squaring
        # the deviation from the updated mean makes each increment too small
        # by a factor n/(n+1). Welford's update with d = z - mean_old is
        #     V_new = (n*V + n/(n+1) * d^2) / (n+1).
        delta = z - self.mean
        self.mean = (n * self.mean + z) / (n + 1)
        self.var = (n * self.var + (n / (n + 1)) * delta * delta) / (n + 1)

        self.count += 1
        return 0
|
|
|
|
|
|
|
|
#--------------------
|
|
|
|
class trainer:
    """Collect per-layer statistics from training data and transfer them to a model.

    ``self.model`` mirrors ``gen_model.model``:
      - ``self.model[0]`` is a ``base_storage_layer``;
      - ``self.model[1:]`` are ``storage_layer`` objects whose grid shape
        doubles from one level to the next (see ``construct``).

    The trainer's state is persisted with pickle in ``<name>.txt``. ``name``
    is the variable the trainer is assigned to at the call site.

    From this class, ``gen_model`` is expected to provide:
      - ``nlevel``;
      - ``model``, a list of layers with ``kernel.apply`` and ``a``/``b``/``w``
        attributes;
      - ``get_matrix_inv``;
      - ``get_matrix_sqrt``.
    """

    def __init__(self, gen_model):
        """Attach to ``gen_model`` and restore (or create) the saved state.

        If ``<name>.txt`` can be loaded, its state replaces the freshly built
        one, while the live ``gen_model`` passed here is kept. Otherwise the
        fresh state is saved to that file.
        """
        self.gen_model = gen_model
        self.nlevel = gen_model.nlevel
        self.model = self.construct(nlevel=self.nlevel)

        # set name from variable name. http://stackoverflow.com/questions/1690400/getting-an-instance-name-inside-class-init
        caller = traceback.extract_stack()[-2]
        self.name = self._variable_name(caller.line)

        print('Trainer model:', self.name)

        try:
            print('load initial state')
            self.load()
        except Exception:
            # Any failure to restore (missing file, unreadable pickle, ...)
            # falls back to persisting the fresh state; KeyboardInterrupt and
            # SystemExit are no longer swallowed.
            print('save initial state')
            self.save()

    @staticmethod
    def _variable_name(text):
        """Return the assignment target in a call-site source line.

        Returns 'trainer' when the line is unavailable (None/empty) or does
        not look like ``target = ...``.
        """
        if not text or '=' not in text:
            return 'trainer'
        target = text[:text.find('=')].strip()
        # A '(' before the first '=' means a keyword argument, not an
        # assignment target, e.g. "trainer(gm, x=1)".
        if not target or '(' in target:
            return 'trainer'
        return target

    def construct(self, nlevel=2):
        """Build the storage layers: one base layer followed by nlevel storage layers.

        Layer 0 is a ``base_storage_layer`` on a (1, 1, 1) grid. Storage layer
        k (k = 1..nlevel) has grid shape (2**(k-1),) * 3.
        """
        inputshape = (1, 1, 1)
        trainer_model = [base_storage_layer(inputshape)]
        for _ in np.arange(nlevel):
            trainer_model.append(storage_layer(inputshape))
            inputshape = tuple(2 * size for size in inputshape)
        return trainer_model

    def transfer(self, silent=False):
        """Write the accumulated statistics into gen_model's layer parameters.

        For each non-base layer pair (m, t), with S = t.var (8x8 covariance of
        components 0-7) and L = get_matrix_inv(S):
            m.Minv = L[:7, :7]
            R      = get_matrix_inv(L[:7, :7])
            m.a    = -R @ L[:7, 7]       (coefficients of component 7)
            m.w    = get_matrix_sqrt(R)
            m.b    = mean[:7] - m.a * mean[7]
        For a covariance S, R is the Schur complement, i.e. the conditional
        covariance of components 0-6 given component 7.

        Base layer: b is set to the running mean array itself (same object,
        not a copy), and w = sqrt(var).

        Assumes get_matrix_inv / get_matrix_sqrt act on the two trailing axes
        of a stacked array -- confirm against gen_model.

        Returns:
            0
        """
        if not silent:
            print('Train model...')

        for m, t in zip(self.gen_model.model[::-1][:-1], self.model[::-1][:-1]):
            M_z_inv = self.gen_model.get_matrix_inv(t.var)

            R_inv = M_z_inv[..., 0:7, 0:7]
            m.Minv = np.copy(R_inv)
            R = self.gen_model.get_matrix_inv(R_inv)

            # a_n = -sum_l L[7, l] * R[l, n]. This is assigned, not accumulated
            # with +=: accumulating made every repeated transfer() add the
            # coefficient again and corrupted b below.
            m.a[..., 0:7] = -np.einsum('...l,...ln->...n', M_z_inv[..., 7, 0:7], R)

            m.w = self.gen_model.get_matrix_sqrt(R)

            # Intercepts: b_n = mean_n - a_n * mean_7 for n = 0..6.
            m.b[..., 0:7] = t.mean[..., 0:7] - m.a[..., 0:7] * t.mean[..., 7:8]

        # update model parameters of base layer
        self.gen_model.model[0].b = self.model[0].mean
        self.gen_model.model[0].w = np.sqrt(self.model[0].var)

        if not silent:
            print('Training done')

        return 0

    def train_single(self, x_train, silent=False):
        """Push one training sample through the layers and update their statistics.

        Layers are visited from the last (largest grid) toward layer 1. At
        each layer:
          - ``m.kernel.apply(..., direction='bwd')`` produces ``z``;
          - ``z`` is accumulated into the matching storage layer;
          - component 7 of ``z`` becomes the input for the next layer.
        The final value is accumulated into the base layer ``self.model[0]``.
        """
        assert isinstance(self.gen_model.model, list), "gen_model must be a list"
        assert isinstance(self.model, list), "trainer_model must be a list"

        # work through reversed model list and ignore base layer
        if not silent:
            print('Train model...')

        for m, t in zip(self.gen_model.model[::-1][:-1], self.model[::-1][:-1]):
            z = m.kernel.apply(x_train, direction='bwd')
            t.update_datum(z)
            x_train = z[..., 7]

        self.model[0].update_datum(x_train)

        if not silent:
            print('Training done')

    def save(self):
        """save class as self.name.txt (pickles the whole instance dict, gen_model included)"""
        with open(self.name + '.txt', 'wb') as file:
            file.write(cPickle.dumps(self.__dict__))

    def load(self):
        """try load self.name.txt, keeping the live gen_model reference.

        Replacing ``self.__dict__`` wholesale used to also replace the
        ``gen_model`` passed to the constructor with a stale unpickled copy.
        ``transfer`` would then update that copy instead of the caller's model.

        NOTE: this unpickles a file from disk; only load files this class
        wrote itself (pickle is unsafe on untrusted data).
        """
        with open(self.name + '.txt', 'rb') as file:
            state = cPickle.loads(file.read())

        live_gen_model = self.__dict__.get('gen_model')
        self.__dict__ = state
        if live_gen_model is not None:
            self.gen_model = live_gen_model
|
|
|
|
|
|
|
|
|