Change __dict__ to getattr

This commit is contained in:
Yin Li 2019-12-23 16:04:35 -05:00
parent 77710bc8a3
commit 848dc87169
7 changed files with 16 additions and 13 deletions

View File

@ -4,7 +4,7 @@ from . import cosmology
def import_norm(path):
mod, func = path.rsplit('.', 1)
mod, fun = path.rsplit('.', 1)
mod = import_module('.' + mod, __name__)
func = getattr(mod, func)
return func
fun = getattr(mod, fun)
return fun

View File

@ -25,8 +25,10 @@ def test(args):
in_channels, out_channels = test_dataset.channels
model = models.__dict__[args.model](in_channels, out_channels)
criterion = torch.nn.__dict__[args.criterion]()
model = getattr(models, args.model)
model = model(in_channels, out_channels)
criterion = getattr(torch.nn, args.criterion)
criterion = criterion()
device = torch.device('cpu')
state = torch.load(args.load_state, map_location=device)

View File

@ -82,21 +82,24 @@ def gpu_worker(local_rank, args):
in_channels, out_channels = train_dataset.channels
model = models.__dict__[args.model](in_channels, out_channels)
model = getattr(models, args.model)
model = model(in_channels, out_channels)
model.to(args.device)
model = DistributedDataParallel(model, device_ids=[args.device])
criterion = torch.nn.__dict__[args.criterion]()
criterion = getattr(torch.nn, args.criterion)
criterion = criterion()
criterion.to(args.device)
optimizer = torch.optim.__dict__[args.optimizer](
optimizer = getattr(torch.optim, args.optimizer)
optimizer = optimizer(
model.parameters(),
lr=args.lr,
#momentum=args.momentum,
weight_decay=args.weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=0.5, patience=2, verbose=True)
factor=0.5, patience=3, verbose=True)
if args.load_state:
state = torch.load(args.load_state, map_location=args.device)

View File

@ -35,7 +35,7 @@ tgt_files="$files"
m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--norms cosmology.dis --crop 128 --pad 42 \
--norms cosmology.dis --crop 256 --pad 42 \
--model VNet \
--load-state best_model.pth \
--batches 1 --loader-workers 0 \

View File

@ -1,7 +1,6 @@
#!/bin/bash
#SBATCH --job-name=dis2dis
#SBATCH --dependency=singleton
#SBATCH --output=%x-%j.out
#SBATCH --partition=gpu

View File

@ -35,7 +35,7 @@ tgt_files="$files"
m2m.py test \
--test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
--test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
--norms cosmology.vel --crop 128 --pad 42 \
--norms cosmology.vel --crop 256 --pad 42 \
--model VNet \
--load-state best_model.pth \
--batches 1 --loader-workers 0 \

View File

@ -1,7 +1,6 @@
#!/bin/bash
#SBATCH --job-name=vel2vel
#SBATCH --dependency=singleton
#SBATCH --output=%x-%j.out
#SBATCH --partition=gpu