Cache data during the first epoch
This commit is contained in:
parent
01b0c8b514
commit
77710bc8a3
@ -72,11 +72,8 @@ class FieldDataset(Dataset):
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = []
|
||||
self.tgt_fields = []
|
||||
for idx in range(len(self.in_files)):
|
||||
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
|
||||
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
|
||||
self.in_fields = {}
|
||||
self.tgt_fields = {}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.in_files) * self.tot_reps
|
||||
@ -96,8 +93,14 @@ class FieldDataset(Dataset):
|
||||
#print(f'self.pad = {self.pad}')
|
||||
|
||||
if self.cache:
|
||||
in_fields = self.in_fields[idx]
|
||||
tgt_fields = self.tgt_fields[idx]
|
||||
try:
|
||||
in_fields = self.in_fields[idx]
|
||||
tgt_fields = self.tgt_fields[idx]
|
||||
except KeyError:
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
self.in_fields[idx] = in_fields
|
||||
self.tgt_fields[idx] = tgt_fields
|
||||
else:
|
||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||
|
@ -96,7 +96,7 @@ def gpu_worker(local_rank, args):
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
||||
factor=0.1, verbose=True)
|
||||
factor=0.5, patience=2, verbose=True)
|
||||
|
||||
if args.load_state:
|
||||
state = torch.load(args.load_state, map_location=args.device)
|
||||
|
Loading…
Reference in New Issue
Block a user