Remove unnecessary arguments --in-channels and --out-channels
This commit is contained in:
parent
f64b1e42e9
commit
0764a1006e
@ -16,10 +16,6 @@ def get_args():
|
||||
|
||||
|
||||
def add_common_args(parser):
|
||||
parser.add_argument('--in-channels', type=int, required=True,
|
||||
help='number of input channels')
|
||||
parser.add_argument('--out-channels', type=int, required=True,
|
||||
help='number of output or target channels')
|
||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||
'of normalization functions from data.norms')
|
||||
parser.add_argument('--criterion', default='MSELoss',
|
||||
|
@ -32,6 +32,9 @@ class FieldDataset(Dataset):
|
||||
assert len(self.in_files) == len(self.tgt_files), \
|
||||
'input and target sample sizes do not match'
|
||||
|
||||
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
|
||||
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
|
||||
|
||||
if isinstance(pad_or_crop, int):
|
||||
pad_or_crop = (pad_or_crop,) * 6
|
||||
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
|
||||
@ -46,6 +49,10 @@ class FieldDataset(Dataset):
|
||||
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
|
||||
self.norms = norms
|
||||
|
||||
@property
|
||||
def channels(self):
|
||||
return self.in_channels, self.tgt_channels
|
||||
|
||||
def __len__(self):
|
||||
return len(self.in_files)
|
||||
|
||||
|
@ -21,7 +21,9 @@ def test(args):
|
||||
num_workers=args.loader_workers,
|
||||
)
|
||||
|
||||
model = UNet(args.in_channels, args.out_channels)
|
||||
in_channels, out_channels = test_dataset.channels
|
||||
|
||||
model = UNet(in_channels, out_channels)
|
||||
criterion = torch.nn.__dict__[args.criterion]()
|
||||
|
||||
device = torch.device('cpu')
|
||||
|
@ -80,7 +80,9 @@ def gpu_worker(local_rank, args):
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
model = UNet(args.in_channels, args.out_channels)
|
||||
in_channels, out_channels = train_dataset.channels
|
||||
|
||||
model = UNet(in_channels, out_channels)
|
||||
model.to(args.device)
|
||||
model = DistributedDataParallel(model, device_ids=[args.device])
|
||||
|
||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
||||
m2m.py test \
|
||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||
--in-channels 3 --out-channels 3 --norms cosmology.dis \
|
||||
--norms cosmology.dis \
|
||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||
--load-state best_model.pth
|
||||
|
||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
|
||||
--norms cosmology.dis --augment \
|
||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||
# --load-state checkpoint.pth
|
||||
|
||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
||||
m2m.py test \
|
||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||
--in-channels 3 --out-channels 3 --norms cosmology.vel \
|
||||
--norms cosmology.vel \
|
||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||
--load-state best_model.pth
|
||||
|
||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
|
||||
--norms cosmology.vel --augment \
|
||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||
# --load-state checkpoint.pth
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user