Remove lazy argument feeding from field dataset

This commit is contained in:
Yin Li 2020-02-07 09:49:27 -05:00
parent 291dfb24b3
commit 1cd34c2eed
3 changed files with 29 additions and 5 deletions

View File

@ -34,8 +34,7 @@ class FieldDataset(Dataset):
""" """
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0, augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
cache=False, div_data=False, rank=None, world_size=None, cache=False, div_data=False, rank=None, world_size=None):
**kwargs):
in_file_lists = [sorted(glob(p)) for p in in_patterns] in_file_lists = [sorted(glob(p)) for p in in_patterns]
self.in_files = list(zip(* in_file_lists)) self.in_files = list(zip(* in_file_lists))

View File

@ -13,8 +13,14 @@ def test(args):
test_dataset = FieldDataset( test_dataset = FieldDataset(
in_patterns=args.test_in_patterns, in_patterns=args.test_in_patterns,
tgt_patterns=args.test_tgt_patterns, tgt_patterns=args.test_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
augment=False, augment=False,
**vars(args), crop=args.crop,
pad=args.pad,
scale_factor=args.scale_factor,
noise_chan=args.noise_chan,
cache=args.cache,
) )
test_loader = DataLoader( test_loader = DataLoader(
test_dataset, test_dataset,

View File

@ -48,7 +48,17 @@ def gpu_worker(local_rank, args):
train_dataset = FieldDataset( train_dataset = FieldDataset(
in_patterns=args.train_in_patterns, in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns, tgt_patterns=args.train_tgt_patterns,
**vars(args), in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
augment=args.augment,
crop=args.crop,
pad=args.pad,
scale_factor=args.scale_factor,
noise_chan=args.noise_chan,
cache=args.cache,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
) )
if not args.div_data: if not args.div_data:
#train_sampler = DistributedSampler(train_dataset, shuffle=True) #train_sampler = DistributedSampler(train_dataset, shuffle=True)
@ -68,8 +78,17 @@ def gpu_worker(local_rank, args):
val_dataset = FieldDataset( val_dataset = FieldDataset(
in_patterns=args.val_in_patterns, in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns, tgt_patterns=args.val_tgt_patterns,
in_norms=args.in_norms,
tgt_norms=args.tgt_norms,
augment=False, augment=False,
**{k: v for k, v in vars(args).items() if k != 'augment'}, crop=args.crop,
pad=args.pad,
scale_factor=args.scale_factor,
noise_chan=args.noise_chan,
cache=args.cache,
div_data=args.div_data,
rank=rank,
world_size=args.world_size,
) )
if not args.div_data: if not args.div_data:
#val_sampler = DistributedSampler(val_dataset, shuffle=False) #val_sampler = DistributedSampler(val_dataset, shuffle=False)