Change __dict__ to getattr
parent 77710bc8a3
commit 848dc87169

7 changed files with 16 additions and 13 deletions
@@ -4,7 +4,7 @@ from . import cosmology


 def import_norm(path):
-    mod, func = path.rsplit('.', 1)
+    mod, fun = path.rsplit('.', 1)
     mod = import_module('.' + mod, __name__)
-    func = getattr(mod, func)
-    return func
+    fun = getattr(mod, fun)
+    return fun
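A usage sketch, not part of the commit: import_norm resolves a dotted path such as the cosmology.dis passed by the test scripts below into the corresponding function via getattr. The signature of the returned normalization callable is assumed here.

    # hypothetical call, made from the package that defines import_norm
    norm = import_norm('cosmology.dis')   # i.e. getattr(cosmology, 'dis')
    # norm now refers to the cosmology.dis callable; its arguments are project-specific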
@@ -25,8 +25,10 @@ def test(args):

     in_channels, out_channels = test_dataset.channels

-    model = models.__dict__[args.model](in_channels, out_channels)
-    criterion = torch.nn.__dict__[args.criterion]()
+    model = getattr(models, args.model)
+    model = model(in_channels, out_channels)
+    criterion = getattr(torch.nn, args.criterion)
+    criterion = criterion()

     device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)
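For a plain module attribute the two lookups return the same object; getattr is simply the idiomatic spelling and raises AttributeError rather than KeyError when the name is misspelled. A minimal sketch of the pattern, using the standard math module as a stand-in for the project's models module and a literal name in place of args.model:

    import math

    name = 'sqrt'
    fn = getattr(math, name)        # same object as math.__dict__[name]
    assert fn is math.__dict__[name]
    print(fn(2.0))                  # 1.4142135623730951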
@@ -82,21 +82,24 @@ def gpu_worker(local_rank, args):

     in_channels, out_channels = train_dataset.channels

-    model = models.__dict__[args.model](in_channels, out_channels)
+    model = getattr(models, args.model)
+    model = model(in_channels, out_channels)
     model.to(args.device)
     model = DistributedDataParallel(model, device_ids=[args.device])

-    criterion = torch.nn.__dict__[args.criterion]()
+    criterion = getattr(torch.nn, args.criterion)
+    criterion = criterion()
     criterion.to(args.device)

-    optimizer = torch.optim.__dict__[args.optimizer](
+    optimizer = getattr(torch.optim, args.optimizer)
+    optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
         #momentum=args.momentum,
         weight_decay=args.weight_decay,
     )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
-            factor=0.5, patience=2, verbose=True)
+            factor=0.5, patience=3, verbose=True)

     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
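The same construction split, sketched with stand-ins so it runs outside the project: a toy linear model in place of the VNet built from args.model, and literal names and hyperparameters in place of the args values.

    import torch

    model = torch.nn.Linear(8, 8)                 # stand-in for the VNet model
    criterion = getattr(torch.nn, 'MSELoss')      # pattern used for args.criterion
    criterion = criterion()
    optimizer = getattr(torch.optim, 'Adam')      # pattern used for args.optimizer
    optimizer = optimizer(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=3)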
@@ -35,7 +35,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.dis --crop 128 --pad 42 \
+    --norms cosmology.dis --crop 256 --pad 42 \
     --model VNet \
     --load-state best_model.pth \
     --batches 1 --loader-workers 0 \
@@ -1,7 +1,6 @@
 #!/bin/bash

 #SBATCH --job-name=dis2dis
-#SBATCH --dependency=singleton
 #SBATCH --output=%x-%j.out

 #SBATCH --partition=gpu
@@ -35,7 +35,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.vel --crop 128 --pad 42 \
+    --norms cosmology.vel --crop 256 --pad 42 \
     --model VNet \
     --load-state best_model.pth \
     --batches 1 --loader-workers 0 \
@@ -1,7 +1,6 @@
 #!/bin/bash

 #SBATCH --job-name=vel2vel
-#SBATCH --dependency=singleton
 #SBATCH --output=%x-%j.out

 #SBATCH --partition=gpu