Remove lazy argument feeding from field dataset
This commit is contained in:
parent
291dfb24b3
commit
1cd34c2eed
@ -34,8 +34,7 @@ class FieldDataset(Dataset):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||||
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
||||||
cache=False, div_data=False, rank=None, world_size=None,
|
cache=False, div_data=False, rank=None, world_size=None):
|
||||||
**kwargs):
|
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
self.in_files = list(zip(* in_file_lists))
|
self.in_files = list(zip(* in_file_lists))
|
||||||
|
|
||||||
|
@ -13,8 +13,14 @@ def test(args):
|
|||||||
test_dataset = FieldDataset(
|
test_dataset = FieldDataset(
|
||||||
in_patterns=args.test_in_patterns,
|
in_patterns=args.test_in_patterns,
|
||||||
tgt_patterns=args.test_tgt_patterns,
|
tgt_patterns=args.test_tgt_patterns,
|
||||||
|
in_norms=args.in_norms,
|
||||||
|
tgt_norms=args.tgt_norms,
|
||||||
augment=False,
|
augment=False,
|
||||||
**vars(args),
|
crop=args.crop,
|
||||||
|
pad=args.pad,
|
||||||
|
scale_factor=args.scale_factor,
|
||||||
|
noise_chan=args.noise_chan,
|
||||||
|
cache=args.cache,
|
||||||
)
|
)
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
|
@ -48,7 +48,17 @@ def gpu_worker(local_rank, args):
|
|||||||
train_dataset = FieldDataset(
|
train_dataset = FieldDataset(
|
||||||
in_patterns=args.train_in_patterns,
|
in_patterns=args.train_in_patterns,
|
||||||
tgt_patterns=args.train_tgt_patterns,
|
tgt_patterns=args.train_tgt_patterns,
|
||||||
**vars(args),
|
in_norms=args.in_norms,
|
||||||
|
tgt_norms=args.tgt_norms,
|
||||||
|
augment=args.augment,
|
||||||
|
crop=args.crop,
|
||||||
|
pad=args.pad,
|
||||||
|
scale_factor=args.scale_factor,
|
||||||
|
noise_chan=args.noise_chan,
|
||||||
|
cache=args.cache,
|
||||||
|
div_data=args.div_data,
|
||||||
|
rank=rank,
|
||||||
|
world_size=args.world_size,
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||||
@ -68,8 +78,17 @@ def gpu_worker(local_rank, args):
|
|||||||
val_dataset = FieldDataset(
|
val_dataset = FieldDataset(
|
||||||
in_patterns=args.val_in_patterns,
|
in_patterns=args.val_in_patterns,
|
||||||
tgt_patterns=args.val_tgt_patterns,
|
tgt_patterns=args.val_tgt_patterns,
|
||||||
|
in_norms=args.in_norms,
|
||||||
|
tgt_norms=args.tgt_norms,
|
||||||
augment=False,
|
augment=False,
|
||||||
**{k: v for k, v in vars(args).items() if k != 'augment'},
|
crop=args.crop,
|
||||||
|
pad=args.pad,
|
||||||
|
scale_factor=args.scale_factor,
|
||||||
|
noise_chan=args.noise_chan,
|
||||||
|
cache=args.cache,
|
||||||
|
div_data=args.div_data,
|
||||||
|
rank=rank,
|
||||||
|
world_size=args.world_size,
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user