From f2e9af6d5f222d03c2709135126d6e30d2e5e266 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Sun, 8 Dec 2019 20:58:46 -0500
Subject: [PATCH] Revert scheduler to ReduceLROnPlateau

---
 map2map/train.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index 34c8093..bc78989 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -88,9 +88,8 @@ def gpu_worker(local_rank, args):
         #momentum=args.momentum,
         #weight_decay=args.weight_decay
     )
-    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
-    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
-            base_lr=args.lr * 1e-2, max_lr=args.lr, cycle_momentum=False)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
+            factor=0.1, verbose=True)
 
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
@@ -123,7 +122,7 @@ def gpu_worker(local_rank, args):
 
         val_loss = validate(epoch, val_loader, model, criterion, args)
 
-        #scheduler.step(val_loss)
+        scheduler.step(val_loss)
 
     if args.rank == 0:
         args.logger.close()
@@ -163,8 +162,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, args):
         loss.backward()
         optimizer.step()
 
-        if scheduler is not None:
-            scheduler.step()
+        #if scheduler is not None:  # for batch scheduler
+            #scheduler.step()
 
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
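Usage note: the patch swaps a per-batch scheduler (CyclicLR, stepped inside the training loop) for a per-epoch scheduler (ReduceLROnPlateau, stepped with the validation loss), which is why the scheduler.step() call moves out of train() and into the epoch loop. The sketch below is a minimal standalone illustration of that per-epoch pattern, assuming a toy model, random data, and placeholder hyperparameters; it is not map2map's actual training code.

import torch

# Toy model and optimizer; stand-ins for map2map's model and args.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Epoch-level scheduler: multiplies the learning rate by `factor`
# when the monitored metric stops improving (factor=0.1 as in the patch).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1)

for epoch in range(5):
    # Dummy training step on random data.
    x, y = torch.randn(16, 8), torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # NOTE: no scheduler.step() here; a batch scheduler such as CyclicLR
    # would be stepped at this point instead.

    # Placeholder validation metric; ReduceLROnPlateau needs this value,
    # so it is stepped once per epoch rather than once per batch.
    val_loss = loss.item()
    scheduler.step(val_loss)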