Remove unnecessary arguments --in-channels and --out-channels

This commit is contained in:
Yin Li 2019-12-09 10:19:21 -05:00
parent f64b1e42e9
commit 0764a1006e
8 changed files with 17 additions and 10 deletions

View file

@ -16,10 +16,6 @@ def get_args():
def add_common_args(parser):
parser.add_argument('--in-channels', type=int, required=True,
help='number of input channels')
parser.add_argument('--out-channels', type=int, required=True,
help='number of output or target channels')
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
'of normalization functions from data.norms')
parser.add_argument('--criterion', default='MSELoss',

View file

@ -32,6 +32,9 @@ class FieldDataset(Dataset):
assert len(self.in_files) == len(self.tgt_files), \
'input and target sample sizes do not match'
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
if isinstance(pad_or_crop, int):
pad_or_crop = (pad_or_crop,) * 6
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
@ -46,6 +49,10 @@ class FieldDataset(Dataset):
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
self.norms = norms
@property
def channels(self):
return self.in_channels, self.tgt_channels
def __len__(self):
return len(self.in_files)

View file

@ -21,7 +21,9 @@ def test(args):
num_workers=args.loader_workers,
)
model = UNet(args.in_channels, args.out_channels)
in_channels, out_channels = test_dataset.channels
model = UNet(in_channels, out_channels)
criterion = torch.nn.__dict__[args.criterion]()
device = torch.device('cpu')

View file

@ -80,7 +80,9 @@ def gpu_worker(local_rank, args):
pin_memory=True
)
model = UNet(args.in_channels, args.out_channels)
in_channels, out_channels = train_dataset.channels
model = UNet(in_channels, out_channels)
model.to(args.device)
model = DistributedDataParallel(model, device_ids=[args.device])