From 77710bc8a30cee37da65bdf91ebc469173d946e8 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Wed, 18 Dec 2019 17:51:11 -0500
Subject: [PATCH] Cache data during the first epoch

---
 map2map/data/fields.py | 17 ++++++++++-------
 map2map/train.py       |  2 +-
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/map2map/data/fields.py b/map2map/data/fields.py
index 2c68b30..ee073cb 100644
--- a/map2map/data/fields.py
+++ b/map2map/data/fields.py
@@ -72,11 +72,8 @@ class FieldDataset(Dataset):
 
         self.cache = cache
         if self.cache:
-            self.in_fields = []
-            self.tgt_fields = []
-            for idx in range(len(self.in_files)):
-                self.in_fields.append([np.load(f) for f in self.in_files[idx]])
-                self.tgt_fields.append([np.load(f) for f in self.tgt_files[idx]])
+            self.in_fields = {}
+            self.tgt_fields = {}
 
     def __len__(self):
         return len(self.in_files) * self.tot_reps
@@ -96,8 +93,14 @@ class FieldDataset(Dataset):
         #print(f'self.pad = {self.pad}')
 
         if self.cache:
-            in_fields = self.in_fields[idx]
-            tgt_fields = self.tgt_fields[idx]
+            try:
+                in_fields = self.in_fields[idx]
+                tgt_fields = self.tgt_fields[idx]
+            except KeyError:
+                in_fields = [np.load(f) for f in self.in_files[idx]]
+                tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
+                self.in_fields[idx] = in_fields
+                self.tgt_fields[idx] = tgt_fields
         else:
             in_fields = [np.load(f) for f in self.in_files[idx]]
             tgt_fields = [np.load(f) for f in self.tgt_files[idx]]
diff --git a/map2map/train.py b/map2map/train.py
index 11dfb4f..cfc2bd5 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -96,7 +96,7 @@ def gpu_worker(local_rank, args):
         weight_decay=args.weight_decay,
     )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
-            factor=0.1, verbose=True)
+            factor=0.5, patience=2, verbose=True)
 
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)