diff --git a/map2map/train.py b/map2map/train.py
index 2ae7d87..d453b13 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -88,7 +88,9 @@ def gpu_worker(local_rank, args):
             #momentum=args.momentum,
             #weight_decay=args.weight_decay
         )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=args.lr * 1e-2,
+            max_lr=args.lr)
 
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
@@ -117,11 +119,11 @@ def gpu_worker(local_rank, args):
 
     for epoch in range(args.start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)
-        train(epoch, train_loader, model, criterion, optimizer, args)
+        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
 
         val_loss = validate(epoch, val_loader, model, criterion, args)
 
-        scheduler.step(val_loss)
+        #scheduler.step(val_loss)
 
         if args.rank == 0:
             args.logger.close()
@@ -145,7 +147,7 @@ def gpu_worker(local_rank, args):
     destroy_process_group()
 
 
-def train(epoch, loader, model, criterion, optimizer, args):
+def train(epoch, loader, model, criterion, optimizer, scheduler, args):
     model.train()
 
     for i, (input, target) in enumerate(loader):
@@ -161,6 +163,9 @@ def train(epoch, loader, model, criterion, optimizer, args):
         loss.backward()
         optimizer.step()
 
+        if scheduler is not None:
+            scheduler.step()
+
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             all_reduce(loss)
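
For reference, a minimal self-contained sketch of the same scheduling pattern, not part of the patch: CyclicLR is stepped once per optimizer step rather than once per epoch, and if the optimizer is Adam (as the commented-out momentum argument above suggests), cycle_momentum=False is needed because Adam has no momentum default for the scheduler to cycle. The model, data, and learning-rate values below are placeholders standing in for args.lr and the real training loop.

import torch

# Placeholder model, loss, and data standing in for the real training setup.
model = torch.nn.Linear(8, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Cycle the learning rate between lr * 1e-2 and lr, mirroring
# base_lr=args.lr * 1e-2, max_lr=args.lr in the patch.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=1e-5,
    max_lr=1e-3,
    cycle_momentum=False,  # Adam has no `momentum` default to cycle
)

for input, target in [(torch.randn(4, 8), torch.randn(4, 1))] * 10:
    optimizer.zero_grad()
    loss = criterion(model(input), target)
    loss.backward()
    optimizer.step()
    scheduler.step()  # per batch, unlike ReduceLROnPlateau's per-epoch step(val_loss)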