import numpy as np
import _pickle as cPickle
import traceback


#--------------------
class storage_layer:
    """Running mean and covariance of 8-component data at every site of a grid.

    ``mean`` has shape ``inputshape + (8,)`` and ``var`` has shape
    ``inputshape + (8, 8)``; data passed to :meth:`update_datum` are
    presumably shaped like ``mean``. ``var`` is the population
    (divide-by-count) covariance of all data seen so far.
    """

    def __init__(self, inputshape = (1,1,1)):
        self.inputshape = inputshape
        self.outputshape = inputshape
        self.mean = np.zeros(inputshape + (8,))
        self.var = np.zeros(inputshape + (8,8))
        self.count = 0

    def update_datum(self, z):
        """Fold one datum into the running mean and covariance; return 0.

        Uses the exact Welford-style recurrence
        ``C_n = C_{n-1} + (z - mean_{n-1}) (z - mean_n)^T`` with ``var = C_n / n``.
        """
        delta_old = z - self.mean  # deviation from the mean *before* this datum
        self.mean = (self.count*self.mean + z)/(self.count+1)
        delta_new = z - self.mean  # deviation from the updated mean
        # Batched outer product over the trailing component axis.
        outer = delta_old[..., :, None] * delta_new[..., None, :]
        self.var[...] = (self.count*self.var + outer)/(self.count+1)
        self.count += 1
        return 0


#--------------------
class base_storage_layer:
    """Running element-wise mean and (population) variance over a grid of scalars."""

    def __init__(self, inputshape = (1,1,1)):
        self.inputshape = inputshape
        self.outputshape = inputshape
        self.mean = np.zeros(inputshape)
        self.var = np.zeros(inputshape)
        self.count = 0

    def update_datum(self, z):
        """Fold one datum into the running mean and variance; return 0.

        Uses the exact Welford-style update with the deviation from both the
        previous and the updated mean.
        """
        delta_old = z - self.mean
        self.mean = (self.count*self.mean + z)/(self.count+1)
        delta_new = z - self.mean
        self.var = (self.count*self.var + delta_old*delta_new)/(self.count+1)
        self.count += 1
        return 0


#--------------------
class trainer:
    """Accumulate moment statistics for a generative model and transfer them into it.

    ``gen_model`` must provide ``nlevel``, a list ``model`` whose element 0 is
    the base layer, and the helpers ``get_matrix_inv`` / ``get_matrix_sqrt``
    used by :meth:`transfer`. The trainer keeps one storage layer per model
    layer (see :meth:`construct`) and persists its state to ``<name>.txt``,
    where ``<name>`` is the variable the trainer is assigned to.
    """

    def __init__(self, gen_model):
        self.gen_model = gen_model
        self.nlevel = gen_model.nlevel
        self.model = self.construct(nlevel=self.nlevel)
        # Set name from the variable name on the calling line, e.g. ``tr = trainer(gm)`` -> 'tr'.
        # http://stackoverflow.com/questions/1690400/getting-an-instance-name-inside-class-init
        # Must stay inline: index -2 is the caller of __init__.
        (filename, line_number, function_name, text) = traceback.extract_stack()[-2]
        if text and '=' in text:
            self.name = text[:text.find('=')].strip()
        else:
            # Source line unavailable (interactive/exec) or not an assignment.
            self.name = type(self).__name__
        print('Trainer model:',self.name)
        try:
            print('load initial state')
            self.load()
        except Exception:
            # Best effort: no (usable) saved state yet, so create one.
            print('save initial state')
            self.save()

    def construct(self,nlevel=2):
        """Build the storage layers: a base layer plus ``nlevel`` levels.

        Each level's grid shape doubles along every axis relative to the previous one.
        """
        trainer_model = []
        inputshape = (1,1,1)
        trainer_model.append(base_storage_layer(inputshape))
        for _ in np.arange(nlevel):
            trainer_model.append(storage_layer(inputshape))
            inputshape = tuple(2 * l for l in inputshape)
        return trainer_model

    def transfer(self, silent=False):
        """Write the accumulated statistics into ``gen_model``'s parameters; return 0.

        For every non-base level the joint 8x8 covariance ``t.var`` is inverted
        (via ``gen_model.get_matrix_inv``, assumed to be a batched matrix
        inverse). From it, this sets on the model layer:

        * ``Minv`` - the 0..6 block of that inverse;
        * ``a`` - regression coefficients of components 0..6 on component 7;
        * ``w`` - ``get_matrix_sqrt`` of the inverse of ``Minv``;
        * ``b`` - intercepts.

        The base layer receives the mean as ``b`` and the standard deviation as ``w``.
        Repeated calls give identical results.
        """
        if not silent:
            print('Train model...')
        # Pair layers from the end, skipping the base layer (element 0) of each list.
        for m, t in zip(self.gen_model.model[::-1][:-1], self.model[::-1][:-1]):
            M_z_inv = self.gen_model.get_matrix_inv(t.var)
            R_inv = M_z_inv[..., 0:7, 0:7]
            m.Minv = np.copy(R_inv)  # R_inv is a view into M_z_inv
            R = self.gen_model.get_matrix_inv(R_inv)
            # a_n = -sum_l M_z_inv[7, l] * R[l, n]; assigned (not +=) so the
            # result does not depend on the previous contents of m.a.
            m.a[..., 0:7] = -np.einsum('...l,...ln->...n', M_z_inv[..., 7, 0:7], R)
            m.w = self.gen_model.get_matrix_sqrt(R)
            m.b[..., 0:7] = t.mean[..., 0:7] - m.a[..., 0:7]*t.mean[..., 7:8]
        #update model parameters of base layer
        self.gen_model.model[0].b = self.model[0].mean
        self.gen_model.model[0].w = np.sqrt(self.model[0].var)
        if not silent:
            print('test',self.gen_model.model[0].b)
            print('Training done')
        return 0

    def train_single(self,x_train, silent=False):
        """Accumulate statistics from one training sample.

        Walks the model layers from last to first (base layer excluded).
        Each layer's ``kernel.apply(..., direction='bwd')`` output updates the
        matching storage layer, and its component 7 becomes the input for the
        next layer. The final remainder updates the base storage layer.
        """
        assert isinstance(self.gen_model.model, list), "gen_model must be a list"
        assert isinstance(self.model, list), "trainer_model must be a list"
        #work through reversed model list and ignore base layer
        if not silent:
            print('Train model...')
        for m, t in zip(self.gen_model.model[::-1][:-1], self.model[::-1][:-1]):
            z = m.kernel.apply(x_train, direction='bwd')
            t.update_datum(z)
            x_train = z[..., 7]
        self.model[0].update_datum(x_train)
        if not silent:
            print('Training done')

    def save(self):
        """Pickle the instance ``__dict__`` (including ``gen_model``) to ``self.name + '.txt'``.

        Serializes before opening the file, so a pickling error leaves any
        previous save file intact.
        """
        data = cPickle.dumps(self.__dict__)
        with open(self.name+'.txt','wb') as file:
            file.write(data)

    def load(self):
        """Restore state written by :meth:`save` from ``self.name + '.txt'``.

        The currently attached ``gen_model`` object (if any) is kept, so
        :meth:`transfer` keeps updating the model the caller holds rather than
        the pickled copy. Raises OSError if the file is missing and a pickle
        error if it is corrupt. Only load files you trust: unpickling can
        execute arbitrary code.
        """
        with open(self.name+'.txt','rb') as file:
            state = cPickle.loads(file.read())
        live_gen_model = self.__dict__.get('gen_model')
        self.__dict__ = state
        if live_gen_model is not None:
            self.gen_model = live_gen_model