Cache data during the first epoch
This commit is contained in:
parent
01b0c8b514
commit
77710bc8a3
@ -72,11 +72,8 @@ class FieldDataset(Dataset):
|
|||||||
|
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.in_fields = []
|
self.in_fields = {}
|
||||||
self.tgt_fields = []
|
self.tgt_fields = {}
|
||||||
for idx in range(len(self.in_files)):
|
|
||||||
self.in_fields.append([np.load(f) for f in self.in_files[idx]])
|
|
||||||
self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.in_files) * self.tot_reps
|
return len(self.in_files) * self.tot_reps
|
||||||
@ -96,8 +93,14 @@ class FieldDataset(Dataset):
|
|||||||
#print(f'self.pad = {self.pad}')
|
#print(f'self.pad = {self.pad}')
|
||||||
|
|
||||||
if self.cache:
|
if self.cache:
|
||||||
in_fields = self.in_fields[idx]
|
try:
|
||||||
tgt_fields = self.tgt_fields[idx]
|
in_fields = self.in_fields[idx]
|
||||||
|
tgt_fields = self.tgt_fields[idx]
|
||||||
|
except KeyError:
|
||||||
|
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||||
|
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||||
|
self.in_fields[idx] = in_fields
|
||||||
|
self.tgt_fields[idx] = tgt_fields
|
||||||
else:
|
else:
|
||||||
in_fields = [np.load(f) for f in self.in_files[idx]]
|
in_fields = [np.load(f) for f in self.in_files[idx]]
|
||||||
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
|
||||||
|
@ -96,7 +96,7 @@ def gpu_worker(local_rank, args):
|
|||||||
weight_decay=args.weight_decay,
|
weight_decay=args.weight_decay,
|
||||||
)
|
)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
||||||
factor=0.1, verbose=True)
|
factor=0.5, patience=2, verbose=True)
|
||||||
|
|
||||||
if args.load_state:
|
if args.load_state:
|
||||||
state = torch.load(args.load_state, map_location=args.device)
|
state = torch.load(args.load_state, map_location=args.device)
|
||||||
|
Loading…
Reference in New Issue
Block a user