Remove unnecessary arguments --in-channels and --out-channels

This commit is contained in:
Yin Li 2019-12-09 10:19:21 -05:00
parent f64b1e42e9
commit 0764a1006e
8 changed files with 17 additions and 10 deletions

View File

@ -16,10 +16,6 @@ def get_args():
def add_common_args(parser): def add_common_args(parser):
parser.add_argument('--in-channels', type=int, required=True,
help='number of input channels')
parser.add_argument('--out-channels', type=int, required=True,
help='number of output or target channels')
parser.add_argument('--norms', type=str_list, help='comma-sep. list ' parser.add_argument('--norms', type=str_list, help='comma-sep. list '
'of normalization functions from data.norms') 'of normalization functions from data.norms')
parser.add_argument('--criterion', default='MSELoss', parser.add_argument('--criterion', default='MSELoss',

View File

@ -32,6 +32,9 @@ class FieldDataset(Dataset):
assert len(self.in_files) == len(self.tgt_files), \ assert len(self.in_files) == len(self.tgt_files), \
'input and target sample sizes do not match' 'input and target sample sizes do not match'
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
if isinstance(pad_or_crop, int): if isinstance(pad_or_crop, int):
pad_or_crop = (pad_or_crop,) * 6 pad_or_crop = (pad_or_crop,) * 6
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \ assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
@ -46,6 +49,10 @@ class FieldDataset(Dataset):
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)] norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
self.norms = norms self.norms = norms
@property
def channels(self):
return self.in_channels, self.tgt_channels
def __len__(self): def __len__(self):
return len(self.in_files) return len(self.in_files)

View File

@ -21,7 +21,9 @@ def test(args):
num_workers=args.loader_workers, num_workers=args.loader_workers,
) )
model = UNet(args.in_channels, args.out_channels) in_channels, out_channels = test_dataset.channels
model = UNet(in_channels, out_channels)
criterion = torch.nn.__dict__[args.criterion]() criterion = torch.nn.__dict__[args.criterion]()
device = torch.device('cpu') device = torch.device('cpu')

View File

@ -80,7 +80,9 @@ def gpu_worker(local_rank, args):
pin_memory=True pin_memory=True
) )
model = UNet(args.in_channels, args.out_channels) in_channels, out_channels = train_dataset.channels
model = UNet(in_channels, out_channels)
model.to(args.device) model.to(args.device)
model = DistributedDataParallel(model, device_ids=[args.device]) model = DistributedDataParallel(model, device_ids=[args.device])

View File

@ -37,7 +37,7 @@ tgt_files="$files"
m2m.py test \ m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \ --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--in-channels 3 --out-channels 3 --norms cosmology.dis \ --norms cosmology.dis \
--batches 1 --loader-workers 0 --pad-or-crop 40 \ --batches 1 --loader-workers 0 --pad-or-crop 40 \
--load-state best_model.pth --load-state best_model.pth

View File

@ -42,7 +42,7 @@ srun m2m.py train \
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \ --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \ --norms cosmology.dis --augment \
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001 --epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
# --load-state checkpoint.pth # --load-state checkpoint.pth

View File

@ -37,7 +37,7 @@ tgt_files="$files"
m2m.py test \ m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \ --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \ --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--in-channels 3 --out-channels 3 --norms cosmology.vel \ --norms cosmology.vel \
--batches 1 --loader-workers 0 --pad-or-crop 40 \ --batches 1 --loader-workers 0 --pad-or-crop 40 \
--load-state best_model.pth --load-state best_model.pth

View File

@ -42,7 +42,7 @@ srun m2m.py train \
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \ --train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \ --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \ --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \ --norms cosmology.vel --augment \
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001 --epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
# --load-state checkpoint.pth # --load-state checkpoint.pth