Change ReduceLROnPlateau to CyclicLR
commit 7540154eba
parent 0211eed0ec
@@ -1,4 +1,4 @@
-import os
+mport os
 import shutil
 import torch
 from torch.multiprocessing import spawn
@@ -88,7 +88,9 @@ def gpu_worker(local_rank, args):
         #momentum=args.momentum,
         #weight_decay=args.weight_decay
     )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
+            base_lr=args.lr * 1e-2, max_lr=args.lr)
 
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
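With the defaults, this CyclicLR sweeps the learning rate in a triangular wave between base_lr = args.lr / 100 and max_lr = args.lr, climbing for step_size_up = 2000 scheduler steps and descending for as many. Below is a minimal standalone sketch of the same construction, assuming a placeholder model, an SGD stand-in optimizer, and lr in place of args.lr (none of these names are from the commit). Note that CyclicLR's cycle_momentum option defaults to True and requires a momentum-based optimizer, so an optimizer without a momentum parameter (e.g. Adam) would need cycle_momentum=False.

# Standalone sketch (assumed names; not the commit's code) of the
# scheduler constructed in the hunk above.
import torch

model = torch.nn.Linear(8, 1)          # placeholder model
lr = 1e-3                              # stands in for args.lr
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Same call as in the diff: cycle the LR between lr/100 and lr.
# Defaults: mode='triangular', step_size_up=2000, so the LR climbs
# for 2000 scheduler steps and descends for the next 2000.
scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=lr * 1e-2, max_lr=lr)

for step in range(3):
    optimizer.step()                   # update weights first
    scheduler.step()                   # then advance the LR schedule
    print(step, scheduler.get_last_lr())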
@@ -117,11 +119,11 @@ def gpu_worker(local_rank, args):
 
     for epoch in range(args.start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
-        train(epoch, train_loader, model, criterion, optimizer, args)
+        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
 
         val_loss = validate(epoch, val_loader, model, criterion, args)
 
-        scheduler.step(val_loss)
+        #scheduler.step(val_loss)
 
     if args.rank == 0:
         args.logger.close()
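Note the per-epoch scheduler.step(val_loss) call is commented out rather than removed: ReduceLROnPlateau is metric-driven and stepped once per epoch with the validation loss, whereas CyclicLR follows a fixed waveform and is stepped once per batch, so the scheduler is instead passed into train() and stepped there (see the hunks below).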
@@ -145,7 +147,7 @@ def gpu_worker(local_rank, args):
     destroy_process_group()
 
 
-def train(epoch, loader, model, criterion, optimizer, args):
+def train(epoch, loader, model, criterion, optimizer, scheduler, args):
     model.train()
 
     for i, (input, target) in enumerate(loader):
@@ -161,6 +163,9 @@ def train(epoch, loader, model, criterion, optimizer, args):
         loss.backward()
         optimizer.step()
 
+        if scheduler is not None:
+            scheduler.step()
+
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             all_reduce(loss)
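A minimal sketch of the per-batch stepping pattern this hunk introduces, with hypothetical stand-ins for the script's loader, model, criterion, and optimizer (the names below are illustrative, not the commit's):

# Sketch of the train-loop pattern above; all objects are stand-ins.
import torch

model = torch.nn.Linear(8, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=1e-5, max_lr=1e-3)
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(3)]

for i, (input, target) in enumerate(loader):
    loss = criterion(model(input), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Batch-level schedulers like CyclicLR advance once per batch;
    # the guard keeps the loop usable when no scheduler is passed.
    if scheduler is not None:
        scheduler.step()

Since PyTorch 1.1 the expected order is optimizer.step() before scheduler.step(), which this hunk follows; the None guard keeps train() usable when the scheduler argument is omitted.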