Change adversary optim scheduler to sync with model optim scheduler

This commit is contained in:
Yin Li 2020-01-20 16:25:15 -05:00
parent e7d2435a96
commit 2b7e559910

View File

@ -102,7 +102,7 @@ def gpu_worker(local_rank, args):
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
) )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=0.5, patience=3, verbose=True) factor=0.1, patience=10, verbose=True)
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
args.adv = args.adv_model is not None args.adv = args.adv_model is not None
@ -130,7 +130,8 @@ def gpu_worker(local_rank, args):
betas=(0.5, 0.999), betas=(0.5, 0.999),
weight_decay=args.adv_weight_decay, weight_decay=args.adv_weight_decay,
) )
adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1) adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
factor=0.1, patience=10, verbose=True)
if args.load_state: if args.load_state:
state = torch.load(args.load_state, map_location=args.device) state = torch.load(args.load_state, map_location=args.device)
@ -185,7 +186,7 @@ def gpu_worker(local_rank, args):
scheduler.step(epoch_loss[0]) scheduler.step(epoch_loss[0])
if args.adv: if args.adv:
adv_scheduler.step() adv_scheduler.step(epoch_loss[0])
if args.rank == 0: if args.rank == 0:
print(end='', flush=True) print(end='', flush=True)