Change __dict__ to getattr

Yin Li 2019-12-23 16:04:35 -05:00
parent 77710bc8a3
commit 848dc87169
7 changed files with 16 additions and 13 deletions
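
The change replaces raw namespace indexing, e.g. models.__dict__[args.model], with getattr(models, args.model), so model, criterion, and optimizer classes are resolved by name through ordinary attribute access. A minimal sketch of the two styles, not taken from the repository ('MSELoss' is only an illustrative value for args.criterion):

import torch.nn as nn

name = 'MSELoss'                   # stand-in for args.criterion

criterion_cls = nn.__dict__[name]  # old style: index the module's namespace dict
criterion_cls = getattr(nn, name)  # new style: ordinary attribute lookup
criterion = criterion_cls()        # instantiate once the class is resolved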

@@ -4,7 +4,7 @@ from . import cosmology
 def import_norm(path):
-    mod, func = path.rsplit('.', 1)
+    mod, fun = path.rsplit('.', 1)
     mod = import_module('.' + mod, __name__)
-    func = getattr(mod, func)
-    return func
+    fun = getattr(mod, fun)
+    return fun
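
For context, a self-contained sketch of the dotted-path resolution that import_norm performs, using an absolute import so it runs on its own; 'math.sqrt' is only a stand-in path, not something the repository uses:

from importlib import import_module

def import_by_path(path):
    # Split 'pkg.mod.fun' into a module part and an attribute part.
    mod, fun = path.rsplit('.', 1)
    mod = import_module(mod)   # absolute import; the repo version imports relative to its package
    return getattr(mod, fun)   # fetch the named callable from the module

sqrt = import_by_path('math.sqrt')
print(sqrt(2.0))               # 1.4142135623730951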

@@ -25,8 +25,10 @@ def test(args):
     in_channels, out_channels = test_dataset.channels
-    model = models.__dict__[args.model](in_channels, out_channels)
-    criterion = torch.nn.__dict__[args.criterion]()
+    model = getattr(models, args.model)
+    model = model(in_channels, out_channels)
+    criterion = getattr(torch.nn, args.criterion)
+    criterion = criterion()
     device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)
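
One practical difference the getattr form picks up, shown as a hedged sketch rather than repository code: a mistyped name now fails with an AttributeError that names the module, instead of a bare KeyError from the namespace dict.

import torch.nn as nn

try:
    nn.__dict__['MSELos']      # typo: raises KeyError('MSELos')
except KeyError as err:
    print('old lookup:', err)

try:
    getattr(nn, 'MSELos')      # typo: raises AttributeError mentioning torch.nn
except AttributeError as err:
    print('new lookup:', err)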

@@ -82,21 +82,24 @@ def gpu_worker(local_rank, args):
     in_channels, out_channels = train_dataset.channels
-    model = models.__dict__[args.model](in_channels, out_channels)
+    model = getattr(models, args.model)
+    model = model(in_channels, out_channels)
     model.to(args.device)
     model = DistributedDataParallel(model, device_ids=[args.device])
-    criterion = torch.nn.__dict__[args.criterion]()
+    criterion = getattr(torch.nn, args.criterion)
+    criterion = criterion()
     criterion.to(args.device)
-    optimizer = torch.optim.__dict__[args.optimizer](
+    optimizer = getattr(torch.optim, args.optimizer)
+    optimizer = optimizer(
         model.parameters(),
         lr=args.lr,
         #momentum=args.momentum,
         weight_decay=args.weight_decay,
     )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
-            factor=0.5, patience=2, verbose=True)
+            factor=0.5, patience=3, verbose=True)
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
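
The optimizer half of the same pattern, again as a hedged sketch: the class is looked up by name and then called with the usual keyword arguments. The model, 'Adam', and the hyperparameter values below are placeholders, not the repository's defaults.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                       # stand-in model

optimizer_cls = getattr(torch.optim, 'Adam')  # stand-in for args.optimizer
optimizer = optimizer_cls(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=3, verbose=True)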