Change ReduceLROnPlateau to CyclicLR

This commit is contained in:
Yin Li 2019-12-03 17:40:08 -05:00
parent 0211eed0ec
commit 7540154eba

View File

@ -1,4 +1,4 @@
import os mport os
import shutil import shutil
import torch import torch
from torch.multiprocessing import spawn from torch.multiprocessing import spawn
@ -88,7 +88,9 @@ def gpu_worker(local_rank, args):
#momentum=args.momentum, #momentum=args.momentum,
#weight_decay=args.weight_decay #weight_decay=args.weight_decay
) )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
base_lr=args.lr * 1e-2, max_lr=args.lr)
if args.load_state: if args.load_state:
state = torch.load(args.load_state, map_location=args.device) state = torch.load(args.load_state, map_location=args.device)
@ -117,11 +119,11 @@ def gpu_worker(local_rank, args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train(epoch, train_loader, model, criterion, optimizer, args) train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
val_loss = validate(epoch, val_loader, model, criterion, args) val_loss = validate(epoch, val_loader, model, criterion, args)
scheduler.step(val_loss) #scheduler.step(val_loss)
if args.rank == 0: if args.rank == 0:
args.logger.close() args.logger.close()
@ -145,7 +147,7 @@ def gpu_worker(local_rank, args):
destroy_process_group() destroy_process_group()
def train(epoch, loader, model, criterion, optimizer, args): def train(epoch, loader, model, criterion, optimizer, scheduler, args):
model.train() model.train()
for i, (input, target) in enumerate(loader): for i, (input, target) in enumerate(loader):
@ -161,6 +163,9 @@ def train(epoch, loader, model, criterion, optimizer, args):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if scheduler is not None:
scheduler.step()
batch = epoch * len(loader) + i + 1 batch = epoch * len(loader) + i + 1
if batch % args.log_interval == 0: if batch % args.log_interval == 0:
all_reduce(loss) all_reduce(loss)