2019-11-30 21:32:45 +01:00
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
import torch
|
2020-01-10 02:24:46 +01:00
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.distributed as dist
|
2019-11-30 21:32:45 +01:00
|
|
|
from torch.multiprocessing import spawn
|
|
|
|
from torch.nn.parallel import DistributedDataParallel
|
|
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
|
|
|
|
from .data import FieldDataset
|
2019-12-10 03:53:27 +01:00
|
|
|
from . import models
|
|
|
|
from .models import narrow_like
|
2019-11-30 21:32:45 +01:00
|
|
|
|
|
|
|
|
|
|
|
def node_worker(args):
    """Start one ``gpu_worker`` process for every CUDA device on this node.

    The job layout comes from SLURM environment variables. The following
    fields are recorded on ``args`` before spawning:

    * ``gpus_per_node`` -- ``torch.cuda.device_count()``
    * ``nodes``         -- ``$SLURM_JOB_NUM_NODES``
    * ``world_size``    -- total number of GPU processes in the job
    * ``node``          -- ``$SLURM_NODEID``

    Node 0 prints ``args`` (before ``node`` is attached) so the configuration
    appears once in the job log.

    Raises:
        KeyError: if SLURM_JOB_NUM_NODES or SLURM_NODEID is not set.
    """
    # NOTE: spawn() launches the workers in fresh interpreters, so this seed
    # only affects the current (parent) process and is not inherited by the
    # gpu_worker processes.
    torch.manual_seed(args.seed)
    #torch.backends.cudnn.deterministic = True # NOTE: test perf

    args.gpus_per_node = torch.cuda.device_count()
    args.nodes = int(os.environ['SLURM_JOB_NUM_NODES'])
    args.world_size = args.nodes * args.gpus_per_node

    node_id = int(os.environ['SLURM_NODEID'])
    if node_id == 0:
        print(args)
    args.node = node_id

    spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node)
|
|
|
|
|
|
|
|
|
|
|
|
def gpu_worker(local_rank, args):
    """Train on a single GPU as one rank of the distributed job.

    ``node_worker`` spawns this function ``args.gpus_per_node`` times per
    node. It performs these steps:

    1. Bind the process to GPU ``local_rank`` and compute its global rank.
    2. Join the process group through the ``env://`` rendezvous.
    3. Build the training ``FieldDataset``/``DataLoader``. Build the
       validation ones too, when both val patterns are given.
    4. Build the generator ``model``. If ``args.adv_model`` is set, also
       build an adversarial model. Each model gets its own criterion,
       optimizer and ``ReduceLROnPlateau`` scheduler.
    5. Optionally resume from the checkpoint at ``args.load_state``.
    6. Run ``train``/``validate`` once per epoch from ``args.start_epoch`` to
       ``args.epochs - 1``. Rank 0 also logs to TensorBoard and writes
       ``checkpoint.pth`` and ``best_model_{epoch}.pth`` into the current
       working directory.

    Args:
        local_rank: 0-based process index on this node. It is also used as
            the CUDA device index.
        args: parsed options, shared with ``node_worker``. Mutated in place:
            ``device``, ``rank``, ``val``, ``adv`` and ``start_epoch`` are
            set, plus ``logger`` on rank 0. ``adv_lr`` and
            ``adv_weight_decay`` fall back to ``lr``/``weight_decay`` when
            they are None.
    """
    args.device = torch.device('cuda', local_rank)
    # A bare `torch.cuda.device(...)` call only builds a context manager and
    # leaves the current device untouched (GPU 0). set_device actually makes
    # this GPU the process's current CUDA device.
    torch.cuda.set_device(args.device)

    args.rank = args.gpus_per_node * args.node + local_rank

    # spawn() starts every worker in a fresh interpreter, so the seed applied
    # in node_worker is not inherited. Seed here so args.seed takes effect.
    torch.manual_seed(args.seed)

    dist.init_process_group(
        backend=args.dist_backend,
        init_method='env://',
        world_size=args.world_size,
        rank=args.rank
    )

    train_dataset = FieldDataset(
        in_patterns=args.train_in_patterns,
        tgt_patterns=args.train_tgt_patterns,
        **vars(args),
    )
    # Without div_data, every rank sees the full dataset and a
    # DistributedSampler shards it. With div_data, the loader shuffles
    # locally and no sampler is used.
    if not args.div_data:
        #train_sampler = DistributedSampler(train_dataset, shuffle=True)
        train_sampler = DistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batches,
        shuffle=args.div_data,
        sampler=None if args.div_data else train_sampler,
        num_workers=args.loader_workers,
        pin_memory=True
    )

    # Validation runs only when both input and target patterns are given.
    args.val = args.val_in_patterns is not None and \
            args.val_tgt_patterns is not None
    if args.val:
        val_dataset = FieldDataset(
            in_patterns=args.val_in_patterns,
            tgt_patterns=args.val_tgt_patterns,
            augment=False,
            **{k: v for k, v in vars(args).items() if k != 'augment'},
        )
        if not args.div_data:
            #val_sampler = DistributedSampler(val_dataset, shuffle=False)
            val_sampler = DistributedSampler(val_dataset)
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.batches,
            shuffle=False,
            sampler=None if args.div_data else val_sampler,
            num_workers=args.loader_workers,
            pin_memory=True
        )

    in_channels, out_channels = train_dataset.channels

    model = getattr(models, args.model)
    model = model(in_channels, out_channels)
    model.to(args.device)
    model = DistributedDataParallel(model, device_ids=[args.device],
            process_group=dist.new_group())

    criterion = getattr(torch.nn, args.criterion)
    criterion = criterion()
    criterion.to(args.device)

    # NOTE(review): `betas` is hard-coded, so args.optimizer must name an
    # optimizer class that accepts it (e.g. the Adam family) -- confirm.
    optimizer = getattr(torch.optim, args.optimizer)
    optimizer = optimizer(
        model.parameters(),
        lr=args.lr,
        #momentum=args.momentum,
        betas=(0.5, 0.999),
        weight_decay=args.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
            factor=0.1, patience=10, verbose=True)

    adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
    args.adv = args.adv_model is not None
    if args.adv:
        adv_model = getattr(models, args.adv_model)
        # In conditional-GAN mode the discriminator also sees the input
        # channels, which train/validate concatenate with the output.
        adv_model = adv_model((in_channels + out_channels)
                if args.cgan else out_channels, 1)
        adv_model.to(args.device)
        adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
                process_group=dist.new_group())

        adv_criterion = getattr(torch.nn, args.adv_criterion)
        adv_criterion = adv_criterion()
        adv_criterion.to(args.device)

        if args.adv_lr is None:
            args.adv_lr = args.lr
        if args.adv_weight_decay is None:
            args.adv_weight_decay = args.weight_decay

        adv_optimizer = getattr(torch.optim, args.optimizer)
        adv_optimizer = adv_optimizer(
            adv_model.parameters(),
            lr=args.adv_lr,
            betas=(0.5, 0.999),
            weight_decay=args.adv_weight_decay,
        )
        adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
                factor=0.1, patience=10, verbose=True)

    if args.load_state:
        state = torch.load(args.load_state, map_location=args.device)
        args.start_epoch = state['epoch']
        model.module.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
        if 'adv_model' in state and args.adv:
            adv_model.module.load_state_dict(state['adv_model'])
            adv_optimizer.load_state_dict(state['adv_optimizer'])
            adv_scheduler.load_state_dict(state['adv_scheduler'])
        torch.set_rng_state(state['rng'].cpu()) # move rng state back
        # Only rank 0 tracks the best loss and writes checkpoints.
        if args.rank == 0:
            min_loss = state['min_loss']
            print('checkpoint at epoch {} loaded from {}'.format(
                state['epoch'], args.load_state))
        del state
    else:
        # def init_weights(m):
        #     classname = m.__class__.__name__
        #     if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        #         m.weight.data.normal_(0.0, 0.02)
        #     elif isinstance(m, torch.nn.BatchNorm3d):
        #         m.weight.data.normal_(1.0, 0.02)
        #         m.bias.data.fill_(0)
        # model.apply(init_weights)
        #
        args.start_epoch = 0
        if args.rank == 0:
            min_loss = None

    torch.backends.cudnn.benchmark = True # NOTE: test perf

    if args.rank == 0:
        args.logger = SummaryWriter()

    for epoch in range(args.start_epoch, args.epochs):
        if not args.div_data:
            # Reshuffle the shards differently each epoch.
            train_sampler.set_epoch(epoch)

        train_loss = train(epoch, train_loader,
            model, criterion, optimizer, scheduler,
            adv_model, adv_criterion, adv_optimizer, adv_scheduler,
            args)
        epoch_loss = train_loss

        # When a validation set exists, its loss drives LR scheduling and
        # best-model selection instead of the training loss.
        if args.val:
            val_loss = validate(epoch, val_loader,
                model, criterion, adv_model, adv_criterion,
                args)
            epoch_loss = val_loss

        # epoch_loss[0] is the main (non-adversarial) loss.
        scheduler.step(epoch_loss[0])
        if args.adv:
            adv_scheduler.step(epoch_loss[0])

        if args.rank == 0:
            print(end='', flush=True)
            # NOTE(review): the writer is closed every epoch, presumably to
            # force a flush; later logging calls reuse the same object.
            args.logger.close()

            is_best = min_loss is None or epoch_loss[0] < min_loss[0]
            if is_best:
                min_loss = epoch_loss

            state = {
                'epoch': epoch + 1,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'rng': torch.get_rng_state(),
                'min_loss': min_loss,
            }
            if args.adv:
                state.update({
                    'adv_model': adv_model.module.state_dict(),
                    'adv_optimizer': adv_optimizer.state_dict(),
                    'adv_scheduler': adv_scheduler.state_dict(),
                })
            ckpt_file = 'checkpoint.pth'
            best_file = 'best_model_{}.pth'
            torch.save(state, ckpt_file)
            del state

            if is_best:
                shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
                #if os.path.isfile(best_file.format(epoch)):
                #    os.remove(best_file.format(epoch))

    dist.destroy_process_group()
|
2019-11-30 21:32:45 +01:00
|
|
|
|
|
|
|
|
2020-01-10 02:24:46 +01:00
|
|
|
def train(epoch, loader, model, criterion, optimizer, scheduler,
          adv_model, adv_criterion, adv_optimizer, adv_scheduler, args):
    """Train the generator, and the discriminator when ``args.adv``, for one epoch.

    Per batch:

    * The generator loss ``criterion(output, target)`` is computed.
    * With ``args.adv``:
      - The discriminator is updated on detached generator outputs ("fake",
        label 0) and targets ("real", label 1).
      - A rescaled adversarial term is folded into the generator's gradient.
    * Every ``args.log_interval`` batches, the rank-averaged generator loss is
      logged on rank 0. The adversarial values logged there are rank 0's own,
      not reduced across ranks.

    ``scheduler`` and ``adv_scheduler`` are accepted but not used here.

    Returns:
        A float64 tensor on ``args.device`` holding
        ``[loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real]``. Each
        entry is summed over batches and ranks, then divided by
        ``len(loader) * args.world_size``. The adversarial entries stay 0
        when ``args.adv`` is False.
    """
    model.train()
    if args.adv:
        adv_model.train()

    # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
    epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)

    for i, (input, target) in enumerate(loader):
        input = input.to(args.device, non_blocking=True)
        target = target.to(args.device, non_blocking=True)

        output = model(input)
        target = narrow_like(target, output) # FIXME pad

        loss = criterion(output, target)
        epoch_loss[0] += loss.item()

        if args.adv:
            if args.cgan:
                # `model` is wrapped in DistributedDataParallel, which does
                # not forward attribute lookups to the wrapped network, so
                # `hasattr(model, 'scale_factor')` is always False on the
                # wrapper. Read the attribute from the wrapped module.
                net = getattr(model, 'module', model)
                scale_factor = getattr(net, 'scale_factor', 1)
                if scale_factor != 1:
                    input = F.interpolate(input,
                            scale_factor=scale_factor, mode='trilinear')
                input = narrow_like(input, output)
                # Conditional GAN: the discriminator judges (input, field)
                # pairs.
                output = torch.cat([input, output], dim=1)
                target = torch.cat([input, target], dim=1)

            # discriminator
            eval_out = adv_model(output.detach())
            fake = torch.zeros(1, dtype=torch.float32,
                    device=args.device).expand_as(eval_out)
            adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
            epoch_loss[3] += adv_loss_fake.item()

            eval_tgt = adv_model(target)
            real = torch.ones(1, dtype=torch.float32,
                    device=args.device).expand_as(eval_tgt)
            adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
            epoch_loss[4] += adv_loss_real.item()

            adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
            epoch_loss[2] += adv_loss.item()

            adv_optimizer.zero_grad()
            adv_loss.backward()
            adv_optimizer.step()

            # generator adversarial loss
            eval_out = adv_model(output)
            loss_adv = adv_criterion(eval_out, real) # FIXME try min
            epoch_loss[1] += loss_adv.item()

            # Scale the adversarial gradient so its magnitude matches `loss`.
            # `loss_adv - loss_adv.item()` evaluates to zero, so only its
            # gradient is added and the value of `loss` is unchanged.
            loss_fac = loss.item() / (loss_adv.item() + 1e-8)
            loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch = epoch * len(loader) + i + 1
        if batch % args.log_interval == 0:
            # Only the value is logged from here on; detach it from the graph.
            loss = loss.detach()
            dist.all_reduce(loss)
            loss /= args.world_size
            if args.rank == 0:
                args.logger.add_scalar('loss/batch/train', loss.item(),
                        global_step=batch)
                if args.adv:
                    args.logger.add_scalar('loss/batch/train/adv/G',
                            loss_adv.item(), global_step=batch)
                    args.logger.add_scalars('loss/batch/train/adv/D', {
                        'total': adv_loss.item(),
                        'fake': adv_loss_fake.item(),
                        'real': adv_loss_real.item(),
                    }, global_step=batch)

    dist.all_reduce(epoch_loss)
    epoch_loss /= len(loader) * args.world_size
    if args.rank == 0:
        args.logger.add_scalar('loss/epoch/train', epoch_loss[0],
                global_step=epoch+1)
        if args.adv:
            args.logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
                    global_step=epoch+1)
            args.logger.add_scalars('loss/epoch/train/adv/D', {
                'total': epoch_loss[2],
                'fake': epoch_loss[3],
                'real': epoch_loss[4],
            }, global_step=epoch+1)

    return epoch_loss
|
2019-11-30 21:32:45 +01:00
|
|
|
|
|
|
|
|
2020-01-10 02:24:46 +01:00
|
|
|
def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
    """Evaluate the generator, and the discriminator when ``args.adv``, for one epoch.

    Runs under ``torch.no_grad()`` with both models in eval mode and updates
    no parameters. On rank 0, the epoch averages are logged to
    ``args.logger`` under ``loss/epoch/val...``.

    Returns:
        A float64 tensor on ``args.device`` holding
        ``[loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real]``. Each
        entry is summed over batches and ranks, then divided by
        ``len(loader) * args.world_size``. The adversarial entries stay 0
        when ``args.adv`` is False.
    """
    model.eval()
    if args.adv:
        adv_model.eval()

    # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
    epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)

    with torch.no_grad():
        for input, target in loader:
            input = input.to(args.device, non_blocking=True)
            target = target.to(args.device, non_blocking=True)

            output = model(input)
            target = narrow_like(target, output) # FIXME pad

            loss = criterion(output, target)
            epoch_loss[0] += loss.item()

            if args.adv:
                if args.cgan:
                    # `model` is wrapped in DistributedDataParallel, which
                    # does not forward attribute lookups to the wrapped
                    # network, so `hasattr(model, 'scale_factor')` is always
                    # False on the wrapper. Read it from the wrapped module.
                    net = getattr(model, 'module', model)
                    scale_factor = getattr(net, 'scale_factor', 1)
                    if scale_factor != 1:
                        input = F.interpolate(input,
                                scale_factor=scale_factor, mode='trilinear')
                    input = narrow_like(input, output)
                    output = torch.cat([input, output], dim=1)
                    target = torch.cat([input, target], dim=1)

                # discriminator
                eval_out = adv_model(output)
                fake = torch.zeros(1, dtype=torch.float32,
                        device=args.device).expand_as(eval_out) # FIXME criterion wrapper: both D&G; min reduction; expand_as
                adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
                epoch_loss[3] += adv_loss_fake.item()

                eval_tgt = adv_model(target)
                real = torch.ones(1, dtype=torch.float32,
                        device=args.device).expand_as(eval_tgt)
                adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
                epoch_loss[4] += adv_loss_real.item()

                adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
                epoch_loss[2] += adv_loss.item()

                # generator adversarial loss: without gradients, the
                # discriminator's score on `output` above can be reused.
                loss_adv = adv_criterion(eval_out, real) # FIXME try min
                epoch_loss[1] += loss_adv.item()

    dist.all_reduce(epoch_loss)
    epoch_loss /= len(loader) * args.world_size
    if args.rank == 0:
        args.logger.add_scalar('loss/epoch/val', epoch_loss[0],
                global_step=epoch+1)
        if args.adv:
            args.logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
                    global_step=epoch+1)
            args.logger.add_scalars('loss/epoch/val/adv/D', {
                'total': epoch_loss[2],
                'fake': epoch_loss[3],
                'real': epoch_loss[4],
            }, global_step=epoch+1)

    return epoch_loss
|