Add delay to scheduler by adv_delay
This commit is contained in:
parent
0a2fc9a9e9
commit
7f6578c63e
@ -72,7 +72,8 @@ def add_train_args(parser):
|
|||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
parser.add_argument('--adv-delay', default=0, type=int,
|
parser.add_argument('--adv-delay', default=0, type=int,
|
||||||
help='epoch before updating the generator with adversarial loss')
|
help='epoch before updating the generator with adversarial loss, '
|
||||||
|
'and the learning rate with scheduler')
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
|
@ -187,16 +187,21 @@ def gpu_worker(local_rank, args):
|
|||||||
args)
|
args)
|
||||||
epoch_loss = val_loss
|
epoch_loss = val_loss
|
||||||
|
|
||||||
scheduler.step(epoch_loss[0])
|
if epoch >= args.adv_delay:
|
||||||
if args.adv:
|
scheduler.step(epoch_loss[0])
|
||||||
adv_scheduler.step(epoch_loss[0])
|
if args.adv:
|
||||||
|
adv_scheduler.step(epoch_loss[0])
|
||||||
|
else:
|
||||||
|
scheduler.last_epoch = epoch
|
||||||
|
if args.adv:
|
||||||
|
adv_scheduler.last_epoch = epoch
|
||||||
|
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
print(end='', flush=True)
|
print(end='', flush=True)
|
||||||
args.logger.close()
|
args.logger.close()
|
||||||
|
|
||||||
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
|
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||||
if is_best:
|
if is_best and epoch >= args.adv_delay:
|
||||||
min_loss = epoch_loss
|
min_loss = epoch_loss
|
||||||
|
|
||||||
state = {
|
state = {
|
||||||
|
Loading…
Reference in New Issue
Block a user