Change __dict__ to getattr
parent 77710bc8a3
commit 848dc87169
7 changed files with 16 additions and 13 deletions
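The commit replaces dictionary-style lookups such as models.__dict__[args.model] with the built-in getattr, which resolves the name through the normal attribute protocol and raises AttributeError (rather than KeyError) when the name does not exist. A minimal sketch of the before/after pattern, using torch.nn as the namespace; the name value is only an illustration of what args.criterion might hold:

import torch.nn

name = 'MSELoss'  # stand-in for args.criterion

# before: read the class straight out of the module's namespace dict
criterion_cls = torch.nn.__dict__[name]    # raises KeyError if the name is absent

# after: look the class up as an attribute, as this commit does
criterion_cls = getattr(torch.nn, name)    # raises AttributeError if the name is absent
criterion = criterion_cls()                # instantiation stays a separate step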
@@ -4,7 +4,7 @@ from . import cosmology


 def import_norm(path):
-    mod, func = path.rsplit('.', 1)
+    mod, fun = path.rsplit('.', 1)
     mod = import_module('.' + mod, __name__)
-    func = getattr(mod, func)
-    return func
+    fun = getattr(mod, fun)
+    return fun
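import_norm resolves a dotted string such as cosmology.dis (the value the test scripts below pass via --norms) by splitting off the last component and importing the submodule relative to the package before the getattr lookup; the rename from func to fun does not change that behaviour. A tiny illustration of the splitting step:

path = 'cosmology.dis'          # e.g. the value passed via --norms
mod, fun = path.rsplit('.', 1)  # splits on the last dot only
# mod == 'cosmology', fun == 'dis'; import_norm then imports '.cosmology'
# relative to its package and returns getattr(<module>, 'dis')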
@@ -25,8 +25,10 @@ def test(args):

     in_channels, out_channels = test_dataset.channels

-    model = models.__dict__[args.model](in_channels, out_channels)
-    criterion = torch.nn.__dict__[args.criterion]()
+    model = getattr(models, args.model)
+    model = model(in_channels, out_channels)
+    criterion = getattr(torch.nn, args.criterion)
+    criterion = criterion()

     device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)
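Splitting the lookup from the instantiation keeps the long constructor call readable and makes the failure mode clearer when args.model or args.criterion names something that does not exist. A runnable sketch of the same two-step pattern; the VNet stub, channel counts, and args values here are illustrative stand-ins, not the project's actual definitions:

import torch.nn
from types import SimpleNamespace

# illustrative stand-ins for the project's models package and parsed CLI args
class VNet(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

models = SimpleNamespace(VNet=VNet)
args = SimpleNamespace(model='VNet', criterion='MSELoss')
in_channels, out_channels = 3, 3

# the pattern from this commit: look the class up by name, then instantiate
model = getattr(models, args.model)
model = model(in_channels, out_channels)

criterion = getattr(torch.nn, args.criterion)
criterion = criterion()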
@@ -82,21 +82,24 @@ def gpu_worker(local_rank, args):

     in_channels, out_channels = train_dataset.channels

-    model = models.__dict__[args.model](in_channels, out_channels)
+    model = getattr(models, args.model)
+    model = model(in_channels, out_channels)
     model.to(args.device)
     model = DistributedDataParallel(model, device_ids=[args.device])

-    criterion = torch.nn.__dict__[args.criterion]()
+    criterion = getattr(torch.nn, args.criterion)
+    criterion = criterion()
     criterion.to(args.device)

-    optimizer = torch.optim.__dict__[args.optimizer](
+    optimizer = getattr(torch.optim, args.optimizer)
+    optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
         #momentum=args.momentum,
         weight_decay=args.weight_decay,
     )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
-        factor=0.5, patience=2, verbose=True)
+        factor=0.5, patience=3, verbose=True)

     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
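The optimizer is resolved the same way, and the scheduler tweak only relaxes ReduceLROnPlateau from 2 to 3 non-improving validation rounds before the learning rate is halved (factor=0.5). A short sketch of the pattern; the optimizer name, model, and hyperparameter values are placeholders, and the commented-out momentum mirrors the diff, since only some optimizers (e.g. SGD) accept it:

import torch

model = torch.nn.Linear(4, 4)                    # placeholder model
opt_name, lr, weight_decay = 'Adam', 1e-3, 1e-4  # stand-ins for args.optimizer, args.lr, args.weight_decay

optimizer = getattr(torch.optim, opt_name)
optimizer = optimizer(
    model.parameters(),
    lr=lr,
    # momentum=0.9,  # only meaningful for optimizers such as SGD
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=3, verbose=True)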
@@ -35,7 +35,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.dis --crop 128 --pad 42 \
+    --norms cosmology.dis --crop 256 --pad 42 \
     --model VNet \
     --load-state best_model.pth \
     --batches 1 --loader-workers 0 \
@@ -1,7 +1,6 @@
 #!/bin/bash

 #SBATCH --job-name=dis2dis
-#SBATCH --dependency=singleton
 #SBATCH --output=%x-%j.out

 #SBATCH --partition=gpu
@@ -35,7 +35,7 @@ tgt_files="$files"
 m2m.py test \
     --test-in-patterns "$data_root_dir/$in_dir/$test_dirs/$in_files" \
     --test-tgt-patterns "$data_root_dir/$tgt_dir/$test_dirs/$tgt_files" \
-    --norms cosmology.vel --crop 128 --pad 42 \
+    --norms cosmology.vel --crop 256 --pad 42 \
     --model VNet \
     --load-state best_model.pth \
     --batches 1 --loader-workers 0 \
@@ -1,7 +1,6 @@
 #!/bin/bash

 #SBATCH --job-name=vel2vel
-#SBATCH --dependency=singleton
 #SBATCH --output=%x-%j.out

 #SBATCH --partition=gpu