Remove unnecessary arguments --in-channels and --out-channels
This commit is contained in:
parent
f64b1e42e9
commit
0764a1006e
@ -16,10 +16,6 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def add_common_args(parser):
|
def add_common_args(parser):
|
||||||
parser.add_argument('--in-channels', type=int, required=True,
|
|
||||||
help='number of input channels')
|
|
||||||
parser.add_argument('--out-channels', type=int, required=True,
|
|
||||||
help='number of output or target channels')
|
|
||||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||||
'of normalization functions from data.norms')
|
'of normalization functions from data.norms')
|
||||||
parser.add_argument('--criterion', default='MSELoss',
|
parser.add_argument('--criterion', default='MSELoss',
|
||||||
|
@ -32,6 +32,9 @@ class FieldDataset(Dataset):
|
|||||||
assert len(self.in_files) == len(self.tgt_files), \
|
assert len(self.in_files) == len(self.tgt_files), \
|
||||||
'input and target sample sizes do not match'
|
'input and target sample sizes do not match'
|
||||||
|
|
||||||
|
self.in_channels = sum(np.load(f).shape[0] for f in self.in_files[0])
|
||||||
|
self.tgt_channels = sum(np.load(f).shape[0] for f in self.tgt_files[0])
|
||||||
|
|
||||||
if isinstance(pad_or_crop, int):
|
if isinstance(pad_or_crop, int):
|
||||||
pad_or_crop = (pad_or_crop,) * 6
|
pad_or_crop = (pad_or_crop,) * 6
|
||||||
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
|
assert isinstance(pad_or_crop, tuple) and len(pad_or_crop) == 6, \
|
||||||
@ -46,6 +49,10 @@ class FieldDataset(Dataset):
|
|||||||
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
|
norms = [import_norm(norm) for norm in norms if isinstance(norm, str)]
|
||||||
self.norms = norms
|
self.norms = norms
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self):
|
||||||
|
return self.in_channels, self.tgt_channels
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.in_files)
|
return len(self.in_files)
|
||||||
|
|
||||||
|
@ -21,7 +21,9 @@ def test(args):
|
|||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = UNet(args.in_channels, args.out_channels)
|
in_channels, out_channels = test_dataset.channels
|
||||||
|
|
||||||
|
model = UNet(in_channels, out_channels)
|
||||||
criterion = torch.nn.__dict__[args.criterion]()
|
criterion = torch.nn.__dict__[args.criterion]()
|
||||||
|
|
||||||
device = torch.device('cpu')
|
device = torch.device('cpu')
|
||||||
|
@ -80,7 +80,9 @@ def gpu_worker(local_rank, args):
|
|||||||
pin_memory=True
|
pin_memory=True
|
||||||
)
|
)
|
||||||
|
|
||||||
model = UNet(args.in_channels, args.out_channels)
|
in_channels, out_channels = train_dataset.channels
|
||||||
|
|
||||||
|
model = UNet(in_channels, out_channels)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model = DistributedDataParallel(model, device_ids=[args.device])
|
model = DistributedDataParallel(model, device_ids=[args.device])
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
|||||||
m2m.py test \
|
m2m.py test \
|
||||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.dis \
|
--norms cosmology.dis \
|
||||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||||
--load-state best_model.pth
|
--load-state best_model.pth
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
|||||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
|
--norms cosmology.dis --augment \
|
||||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ tgt_files="$files"
|
|||||||
m2m.py test \
|
m2m.py test \
|
||||||
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
|
||||||
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.vel \
|
--norms cosmology.vel \
|
||||||
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
--batches 1 --loader-workers 0 --pad-or-crop 40 \
|
||||||
--load-state best_model.pth
|
--load-state best_model.pth
|
||||||
|
|
||||||
|
@ -42,7 +42,7 @@ srun m2m.py train \
|
|||||||
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
--train-tgt-patterns "$data_root_dir/$tgt_dir/$train_dirs/$tgt_files" \
|
||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
|
--norms cosmology.vel --augment \
|
||||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user