Cache data during the first epoch

This commit is contained in:
Yin Li 2019-12-18 17:51:11 -05:00
parent 01b0c8b514
commit 77710bc8a3
2 changed files with 11 additions and 8 deletions

View File

@ -72,11 +72,8 @@ class FieldDataset(Dataset):
self.cache = cache
if self.cache:
self.in_fields = []
self.tgt_fields = []
for idx in range(len(self.in_files)):
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
self.in_fields = {}
self.tgt_fields = {}
def __len__(self):
return len(self.in_files) * self.tot_reps
@ -96,8 +93,14 @@ class FieldDataset(Dataset):
#print(f'self.pad = {self.pad}')
if self.cache:
try:
in_fields = self.in_fields[idx]
tgt_fields = self.tgt_fields[idx]
except KeyError:
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
self.in_fields[idx] = in_fields
self.tgt_fields[idx] = tgt_fields
else:
in_fields = [np.load(f) for f in self.in_files[idx]]
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]

View File

@ -96,7 +96,7 @@ def gpu_worker(local_rank, args):
weight_decay=args.weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=0.1, verbose=True)
factor=0.5, patience=2, verbose=True)
if args.load_state:
state = torch.load(args.load_state, map_location=args.device)