Change ReduceLROnPlateau to CyclicLR
This commit is contained in:
parent
0211eed0ec
commit
b253bb687b
@ -88,7 +88,9 @@ def gpu_worker(local_rank, args):
|
|||||||
#momentum=args.momentum,
|
#momentum=args.momentum,
|
||||||
#weight_decay=args.weight_decay
|
#weight_decay=args.weight_decay
|
||||||
)
|
)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
|
||||||
|
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
|
||||||
|
base_lr=args.lr * 1e-2, max_lr=args.lr, cycle_momentum=False)
|
||||||
|
|
||||||
if args.load_state:
|
if args.load_state:
|
||||||
state = torch.load(args.load_state, map_location=args.device)
|
state = torch.load(args.load_state, map_location=args.device)
|
||||||
@ -117,11 +119,11 @@ def gpu_worker(local_rank, args):
|
|||||||
|
|
||||||
for epoch in range(args.start_epoch, args.epochs):
|
for epoch in range(args.start_epoch, args.epochs):
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
train(epoch, train_loader, model, criterion, optimizer, args)
|
train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
|
||||||
|
|
||||||
val_loss = validate(epoch, val_loader, model, criterion, args)
|
val_loss = validate(epoch, val_loader, model, criterion, args)
|
||||||
|
|
||||||
scheduler.step(val_loss)
|
#scheduler.step(val_loss)
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
args.logger.close()
|
args.logger.close()
|
||||||
@ -145,7 +147,7 @@ def gpu_worker(local_rank, args):
|
|||||||
destroy_process_group()
|
destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
def train(epoch, loader, model, criterion, optimizer, args):
|
def train(epoch, loader, model, criterion, optimizer, scheduler, args):
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
for i, (input, target) in enumerate(loader):
|
for i, (input, target) in enumerate(loader):
|
||||||
@ -161,6 +163,9 @@ def train(epoch, loader, model, criterion, optimizer, args):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
batch = epoch * len(loader) + i + 1
|
batch = epoch * len(loader) + i + 1
|
||||||
if batch % args.log_interval == 0:
|
if batch % args.log_interval == 0:
|
||||||
all_reduce(loss)
|
all_reduce(loss)
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
#SBATCH --gres=gpu:v100-32gb:4
|
#SBATCH --gres=gpu:v100-32gb:4
|
||||||
|
|
||||||
#SBATCH --exclusive
|
#SBATCH --exclusive
|
||||||
#SBATCH --nodes=2
|
#SBATCH --nodes=4
|
||||||
#SBATCH --mem=0
|
#SBATCH --mem=0
|
||||||
#SBATCH --time=7-00:00:00
|
#SBATCH --time=7-00:00:00
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ srun m2m.py train \
|
|||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
|
--in-channels 3 --out-channels 3 --norms cosmology.dis --augment \
|
||||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
#SBATCH --gres=gpu:v100-32gb:4
|
#SBATCH --gres=gpu:v100-32gb:4
|
||||||
|
|
||||||
#SBATCH --exclusive
|
#SBATCH --exclusive
|
||||||
#SBATCH --nodes=2
|
#SBATCH --nodes=4
|
||||||
#SBATCH --mem=0
|
#SBATCH --mem=0
|
||||||
#SBATCH --time=7-00:00:00
|
#SBATCH --time=7-00:00:00
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ srun m2m.py train \
|
|||||||
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
--val-in-patterns "$data_root_dir/$in_dir/$val_dirs/$in_files" \
|
||||||
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
--val-tgt-patterns "$data_root_dir/$tgt_dir/$val_dirs/$tgt_files" \
|
||||||
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
|
--in-channels 3 --out-channels 3 --norms cosmology.vel --augment \
|
||||||
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.0002
|
--epochs 1024 --batches 3 --loader-workers 3 --lr 0.001
|
||||||
# --load-state checkpoint.pth
|
# --load-state checkpoint.pth
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user