Add simple label smoothing by --adv-real-label
This commit is contained in:
parent
d2840f01b0
commit
1818e11265
@ -78,6 +78,8 @@ def add_train_args(parser):
|
|||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
parser.add_argument('--adv-start', default=0, type=int,
|
parser.add_argument('--adv-start', default=0, type=int,
|
||||||
help='epoch to start adversarial training')
|
help='epoch to start adversarial training')
|
||||||
|
parser.add_argument('--adv-real-label', default=1, type=float,
|
||||||
|
help='label for real samples, e.g. 0.9 for label smoothing')
|
||||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||||
help='final fraction of loss (vs adv-loss)')
|
help='final fraction of loss (vs adv-loss)')
|
||||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||||
|
@ -274,8 +274,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
# loss_adv: generator (model) adversarial loss
|
# loss_adv: generator (model) adversarial loss
|
||||||
# adv_loss: discriminator (adv_model) loss
|
# adv_loss: discriminator (adv_model) loss
|
||||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
||||||
real = torch.ones(1, dtype=torch.float32, device=device)
|
fake = torch.zeros([1], dtype=torch.float32, device=device)
|
||||||
fake = torch.zeros(1, dtype=torch.float32, device=device)
|
real = torch.full([1], args.adv_real_label, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for i, (input, target) in enumerate(loader):
|
for i, (input, target) in enumerate(loader):
|
||||||
input = input.to(device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
@ -383,8 +383,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
world_size = dist.get_world_size()
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
||||||
fake = torch.zeros(1, dtype=torch.float32, device=device)
|
fake = torch.zeros([1], dtype=torch.float32, device=device)
|
||||||
real = torch.ones(1, dtype=torch.float32, device=device)
|
real = torch.ones([1], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for input, target in loader:
|
for input, target in loader:
|
||||||
|
Loading…
Reference in New Issue
Block a user