From 044e4c93762ffde8a7ebcddf360365d153b2c3c9 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Tue, 3 Dec 2019 17:40:08 -0500
Subject: [PATCH] Change ReduceLROnPlateau to CyclicLR

---
 map2map/train.py      | 13 +++++++++----
 scripts/dis2dis.slurm |  4 ++--
 scripts/vel2vel.slurm |  4 ++--
 3 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index 2ae7d87..120890c 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -88,7 +88,9 @@ def gpu_worker(local_rank, args):
         #momentum=args.momentum,
         #weight_decay=args.weight_decay
     )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
+    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
+            base_lr=args.lr * 1e-2, max_lr=args.lr)

     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
@@ -117,11 +119,11 @@ def gpu_worker(local_rank, args):
     for epoch in range(args.start_epoch, args.epochs):
         train_sampler.set_epoch(epoch)

-        train(epoch, train_loader, model, criterion, optimizer, args)
+        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)

         val_loss = validate(epoch, val_loader, model, criterion, args)

-        scheduler.step(val_loss)
+        #scheduler.step(val_loss)

     if args.rank == 0:
         args.logger.close()
@@ -145,7 +147,7 @@
     destroy_process_group()


-def train(epoch, loader, model, criterion, optimizer, args):
+def train(epoch, loader, model, criterion, optimizer, scheduler, args):
     model.train()

     for i, (input, target) in enumerate(loader):
@@ -161,6 +163,9 @@
         loss.backward()
         optimizer.step()

+        if scheduler is not None:
+            scheduler.step()
+
         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
             all_reduce(loss)
diff --git a/scripts/dis2dis.slurm b/scripts/dis2dis.slurm
index 61cdc3f..cc347af 100644
--- a/scripts/dis2dis.slurm
+++ b/scripts/dis2dis.slurm
@@ -9,7 +9,7 @@
 #SBATCH --gres=gpu:v100-32gb:4
 #SBATCH --exclusive

-#SBATCH --nodes=2
+#SBATCH --nodes=4
 #SBATCH --mem=0
 #SBATCH --time=7-00:00:00

@@ -46,7 +46,7 @@ srun m2m.py train \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
-    --epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002
+    --epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
 #    --load-state checkpoint.pth

diff --git a/scripts/vel2vel.slurm b/scripts/vel2vel.slurm
index 2e3caa9..30cfe0c 100644
--- a/scripts/vel2vel.slurm
+++ b/scripts/vel2vel.slurm
@@ -9,7 +9,7 @@
 #SBATCH --gres=gpu:v100-32gb:4
 #SBATCH --exclusive

-#SBATCH --nodes=2
+#SBATCH --nodes=4
 #SBATCH --mem=0
 #SBATCH --time=7-00:00:00

@@ -46,7 +46,7 @@ srun m2m.py train \
     --val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
     --val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
     --in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
-    --epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002
+    --epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
 #    --load-state checkpoint.pth
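
A note on the scheduler swap (reviewer commentary, not part of the patch): ReduceLROnPlateau is stepped once per epoch with the validation loss, hence the old scheduler.step(val_loss) in gpu_worker's epoch loop, while CyclicLR is stepped once per optimizer update, which is why the scheduler is now threaded through to train() and stepped inside the batch loop. The sketch below is a minimal standalone illustration of that per-batch pattern; the toy model and data are placeholders, and cycle_momentum=False is an assumption worth checking against this code: PyTorch's CyclicLR cycles momentum by default and raises a ValueError for optimizers without a momentum default, such as Adam, which the commented-out momentum kwarg above hints is the optimizer here.

    import torch

    # Placeholders for the model and data that map2map builds elsewhere.
    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Mirror the patch: cycle the learning rate between lr * 1e-2 and lr.
    # cycle_momentum=False avoids the ValueError with momentum-less Adam.
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, base_lr=1e-5, max_lr=1e-3, cycle_momentum=False)

    for input, target in [(torch.randn(8, 3), torch.randn(8, 3))] * 4:
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(input), target)
        loss.backward()
        optimizer.step()
        scheduler.step()  # per batch, unlike ReduceLROnPlateau's per-epoch step(val_loss)

Also worth noting: CyclicLR's default step_size_up is 2000 optimizer steps, so the effective cycle length here depends on how many batches each epoch contributes.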