Remove lazy argument feeding from field dataset
This commit is contained in:
parent
291dfb24b3
commit
1cd34c2eed
@ -34,8 +34,7 @@ class FieldDataset(Dataset):
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
||||
cache=False, div_data=False, rank=None, world_size=None,
|
||||
**kwargs):
|
||||
cache=False, div_data=False, rank=None, world_size=None):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
self.in_files = list(zip(* in_file_lists))
|
||||
|
||||
|
@ -13,8 +13,14 @@ def test(args):
|
||||
test_dataset = FieldDataset(
|
||||
in_patterns=args.test_in_patterns,
|
||||
tgt_patterns=args.test_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
tgt_norms=args.tgt_norms,
|
||||
augment=False,
|
||||
**vars(args),
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
noise_chan=args.noise_chan,
|
||||
cache=args.cache,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
|
@ -48,7 +48,17 @@ def gpu_worker(local_rank, args):
|
||||
train_dataset = FieldDataset(
|
||||
in_patterns=args.train_in_patterns,
|
||||
tgt_patterns=args.train_tgt_patterns,
|
||||
**vars(args),
|
||||
in_norms=args.in_norms,
|
||||
tgt_norms=args.tgt_norms,
|
||||
augment=args.augment,
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
noise_chan=args.noise_chan,
|
||||
cache=args.cache,
|
||||
div_data=args.div_data,
|
||||
rank=rank,
|
||||
world_size=args.world_size,
|
||||
)
|
||||
if not args.div_data:
|
||||
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
@ -68,8 +78,17 @@ def gpu_worker(local_rank, args):
|
||||
val_dataset = FieldDataset(
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
in_norms=args.in_norms,
|
||||
tgt_norms=args.tgt_norms,
|
||||
augment=False,
|
||||
**{k: v for k, v in vars(args).items() if k != 'augment'},
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
noise_chan=args.noise_chan,
|
||||
cache=args.cache,
|
||||
div_data=args.div_data,
|
||||
rank=rank,
|
||||
world_size=args.world_size,
|
||||
)
|
||||
if not args.div_data:
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
|
Loading…
Reference in New Issue
Block a user