Remove label smoothing for generator

This commit is contained in:
Yin Li 2020-02-14 09:22:57 -05:00
parent b67079bf72
commit 5cb4a1bbae
2 changed files with 7 additions and 4 deletions

View File

@ -78,8 +78,9 @@ def add_train_args(parser):
help='enable conditional GAN') help='enable conditional GAN')
parser.add_argument('--adv-start', default=0, type=int, parser.add_argument('--adv-start', default=0, type=int,
help='epoch to start adversarial training') help='epoch to start adversarial training')
parser.add_argument('--adv-real-label', default=1, type=float, parser.add_argument('--adv-label-smoothing', default=1, type=float,
help='label for real samples, e.g. 0.9 for label smoothing') help='label of real samples for discriminator, '
'e.g. 0.9 for label smoothing')
parser.add_argument('--loss-fraction', default=0.5, type=float, parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)') help='final fraction of loss (vs adv-loss)')
parser.add_argument('--loss-halflife', default=20, type=float, parser.add_argument('--loss-halflife', default=20, type=float,

View File

@ -262,7 +262,9 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
# adv_loss: discriminator (adv_model) loss # adv_loss: discriminator (adv_model) loss
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device) epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
fake = torch.zeros([1], dtype=torch.float32, device=device) fake = torch.zeros([1], dtype=torch.float32, device=device)
real = torch.full([1], args.adv_real_label, dtype=torch.float32, device=device) real = torch.ones([1], dtype=torch.float32, device=device)
adv_real = torch.full([1], args.adv_label_smoothing, dtype=torch.float32,
device=device)
for i, (input, target) in enumerate(loader): for i, (input, target) in enumerate(loader):
input = input.to(device, non_blocking=True) input = input.to(device, non_blocking=True)
@ -290,7 +292,7 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
set_requires_grad(adv_model, True) set_requires_grad(adv_model, True)
eval = adv_model([output.detach(), target]) eval = adv_model([output.detach(), target])
adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, real]) adv_loss_fake, adv_loss_real = adv_criterion(eval, [fake, adv_real])
epoch_loss[3] += adv_loss_fake.item() epoch_loss[3] += adv_loss_fake.item()
epoch_loss[4] += adv_loss_real.item() epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real) adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)