Remove lazy argument feeding from field dataset

This commit is contained in:
Yin Li 2020-02-07 09:49:27 -05:00
parent 291dfb24b3
commit 1cd34c2eed
3 changed files with 29 additions and 5 deletions

View File

@ -34,8 +34,7 @@ class FieldDataset(Dataset):
"""
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
cache=False, div_data=False, rank=None, world_size=None,
**kwargs):
cache=False, div_data=False, rank=None, world_size=None):
in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists))

View File

@ -13,8 +13,14 @@ def test(args):
test_dataset = FieldDataset(
in_patterns=args.test_in_patterns,
tgt_patterns=args.test_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
augment=False,
**vars(args),
crop=args.crop,
pad=args.pad,
scale_factor=args.scale_factor,
noise_chan=args.noise_chan,
cache=args.cache,
)
test_loader = DataLoader(
test_dataset,

View File

@ -48,7 +48,17 @@ def gpu_worker(local_rank, args):
train_dataset = FieldDataset(
in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns,
**vars(args),
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
augment=args.augment,
crop=args.crop,
pad=args.pad,
scale_factor=args.scale_factor,
noise_chan=args.noise_chan,
cache=args.cache,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
)
if not args.div_data:
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
@ -68,8 +78,17 @@ def gpu_worker(local_rank, args):
val_dataset = FieldDataset(
in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
augment=False,
**{k: v for k, v in vars(args).items() if k != 'augment'},
crop=args.crop,
pad=args.pad,
scale_factor=args.scale_factor,
noise_chan=args.noise_chan,
cache=args.cache,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
)
if not args.div_data:
#val_sampler = DistributedSampler(val_dataset, shuffle=False)