From bf3fe86afe419b13958f9f50540b60b3e48bf271 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Tue, 3 Dec 2019 17:40:08 -0500
Subject: [PATCH] Change ReduceLROnPlateau to CyclicLR

Switch the learning-rate schedule from ReduceLROnPlateau, which was
stepped once per epoch on the validation loss, to CyclicLR, which must
be stepped once per optimization batch; the scheduler is therefore now
threaded through to train() and stepped right after optimizer.step().
The LR cycles between base_lr = lr/100 and max_lr = lr.

Note: CyclicLR lives in torch.optim.lr_scheduler (not torch.optim), and
since the optimizer here is Adam (its momentum kwarg is commented out),
momentum cycling must be disabled with cycle_momentum=False — otherwise
CyclicLR raises ValueError at construction.
---
 map2map/train.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index 2ae7d87..d453b13 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -88,7 +88,9 @@ def gpu_worker(local_rank, args):
             #momentum=args.momentum,
             #weight_decay=args.weight_decay
         )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.lr * 1e-2,
+            max_lr=args.lr, cycle_momentum=False)
 
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
@@ -117,11 +119,11 @@ def gpu_worker(local_rank, args):
     for epoch in range(args.start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
 
-        train(epoch, train_loader, model, criterion, optimizer, args)
+        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
 
         val_loss = validate(epoch, val_loader, model, criterion, args)
 
-        scheduler.step(val_loss)
+        #scheduler.step(val_loss)
 
     if args.rank == 0:
         args.logger.close()
@@ -145,7 +147,7 @@ def gpu_worker(local_rank, args):
     destroy_process_group()
 
 
-def train(epoch, loader, model, criterion, optimizer, args):
+def train(epoch, loader, model, criterion, optimizer, scheduler, args):
     model.train()
 
     for i, (input, target) in enumerate(loader):
@@ -161,6 +163,9 @@ def train(epoch, loader, model, criterion, optimizer, args):
         loss.backward()
         optimizer.step()
 
+        if scheduler is not None:
+            scheduler.step()
+
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             all_reduce(loss)