Merge branch 'anneal_loss'
Conflicts: map2map/models/patchgan.py map2map/train.py
This commit is contained in:
commit
16b82fcc56
@ -74,9 +74,10 @@ def add_train_args(parser):
|
|||||||
help='enable minimum reduction in adversarial criterion')
|
help='enable minimum reduction in adversarial criterion')
|
||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
parser.add_argument('--adv-delay', default=0, type=int,
|
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||||
help='epoch before updating the generator with adversarial loss, '
|
help='final fraction of loss (vs adv-loss)')
|
||||||
'and the learning rate with scheduler')
|
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||||
|
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
|
192
map2map/train.py
192
map2map/train.py
@ -18,7 +18,24 @@ from .models.adversary import adv_model_wrapper, adv_criterion_wrapper
|
|||||||
from .state import load_model_state_dict
|
from .state import load_model_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def set_runtime_default_args(args):
|
||||||
|
args.val = args.val_in_patterns is not None and \
|
||||||
|
args.val_tgt_patterns is not None
|
||||||
|
|
||||||
|
args.adv = args.adv_model is not None
|
||||||
|
|
||||||
|
if args.adv:
|
||||||
|
if args.adv_lr is None:
|
||||||
|
args.adv_lr = args.lr
|
||||||
|
if args.adv_weight_decay is None:
|
||||||
|
args.adv_weight_decay = args.weight_decay
|
||||||
|
|
||||||
|
args.adv_epoch = 0 # epoch when adversary is initiated
|
||||||
|
|
||||||
|
|
||||||
def node_worker(args):
|
def node_worker(args):
|
||||||
|
set_runtime_default_args(args)
|
||||||
|
|
||||||
torch.manual_seed(args.seed) # NOTE: why here not in gpu_worker?
|
torch.manual_seed(args.seed) # NOTE: why here not in gpu_worker?
|
||||||
#torch.backends.cudnn.deterministic = True # NOTE: test perf
|
#torch.backends.cudnn.deterministic = True # NOTE: test perf
|
||||||
|
|
||||||
@ -31,20 +48,20 @@ def node_worker(args):
|
|||||||
pprint(vars(args))
|
pprint(vars(args))
|
||||||
args.node = node
|
args.node = node
|
||||||
|
|
||||||
spawn(gpu_worker, args=(args,), nprocs=args.gpus_per_node)
|
spawn(gpu_worker, args=(node, args), nprocs=args.gpus_per_node)
|
||||||
|
|
||||||
|
|
||||||
def gpu_worker(local_rank, args):
|
def gpu_worker(local_rank, node, args):
|
||||||
args.device = torch.device('cuda', local_rank)
|
device = torch.device('cuda', local_rank)
|
||||||
torch.cuda.device(args.device)
|
torch.cuda.device(device)
|
||||||
|
|
||||||
args.rank = args.gpus_per_node * args.node + local_rank
|
rank = args.gpus_per_node * node + local_rank
|
||||||
|
|
||||||
dist.init_process_group(
|
dist.init_process_group(
|
||||||
backend=args.dist_backend,
|
backend=args.dist_backend,
|
||||||
init_method='env://',
|
init_method='env://',
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
rank=args.rank
|
rank=rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset = FieldDataset(
|
train_dataset = FieldDataset(
|
||||||
@ -59,7 +76,7 @@ def gpu_worker(local_rank, args):
|
|||||||
noise_chan=args.noise_chan,
|
noise_chan=args.noise_chan,
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
div_data=args.div_data,
|
div_data=args.div_data,
|
||||||
rank=args.rank,
|
rank=rank,
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
@ -71,11 +88,9 @@ def gpu_worker(local_rank, args):
|
|||||||
shuffle=args.div_data,
|
shuffle=args.div_data,
|
||||||
sampler=None if args.div_data else train_sampler,
|
sampler=None if args.div_data else train_sampler,
|
||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
pin_memory=True
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
args.val = args.val_in_patterns is not None and \
|
|
||||||
args.val_tgt_patterns is not None
|
|
||||||
if args.val:
|
if args.val:
|
||||||
val_dataset = FieldDataset(
|
val_dataset = FieldDataset(
|
||||||
in_patterns=args.val_in_patterns,
|
in_patterns=args.val_in_patterns,
|
||||||
@ -89,7 +104,7 @@ def gpu_worker(local_rank, args):
|
|||||||
noise_chan=args.noise_chan,
|
noise_chan=args.noise_chan,
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
div_data=args.div_data,
|
div_data=args.div_data,
|
||||||
rank=args.rank,
|
rank=rank,
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
@ -101,20 +116,20 @@ def gpu_worker(local_rank, args):
|
|||||||
shuffle=False,
|
shuffle=False,
|
||||||
sampler=None if args.div_data else val_sampler,
|
sampler=None if args.div_data else val_sampler,
|
||||||
num_workers=args.loader_workers,
|
num_workers=args.loader_workers,
|
||||||
pin_memory=True
|
pin_memory=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||||
|
|
||||||
model = getattr(models, args.model)
|
model = getattr(models, args.model)
|
||||||
model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
|
model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
|
||||||
model.to(args.device)
|
model.to(device)
|
||||||
model = DistributedDataParallel(model, device_ids=[args.device],
|
model = DistributedDataParallel(model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
|
||||||
criterion = getattr(torch.nn, args.criterion)
|
criterion = getattr(torch.nn, args.criterion)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
criterion.to(args.device)
|
criterion.to(device)
|
||||||
|
|
||||||
optimizer = getattr(torch.optim, args.optimizer)
|
optimizer = getattr(torch.optim, args.optimizer)
|
||||||
optimizer = optimizer(
|
optimizer = optimizer(
|
||||||
@ -128,25 +143,19 @@ def gpu_worker(local_rank, args):
|
|||||||
factor=0.1, patience=10, verbose=True)
|
factor=0.1, patience=10, verbose=True)
|
||||||
|
|
||||||
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
|
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
|
||||||
args.adv = args.adv_model is not None
|
|
||||||
if args.adv:
|
if args.adv:
|
||||||
adv_model = getattr(models, args.adv_model)
|
adv_model = getattr(models, args.adv_model)
|
||||||
adv_model = adv_model_wrapper(adv_model)
|
adv_model = adv_model_wrapper(adv_model)
|
||||||
adv_model = adv_model(sum(args.in_chan + args.out_chan)
|
adv_model = adv_model(sum(args.in_chan + args.out_chan)
|
||||||
if args.cgan else sum(args.out_chan), 1)
|
if args.cgan else sum(args.out_chan), 1)
|
||||||
adv_model.to(args.device)
|
adv_model.to(device)
|
||||||
adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
|
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
|
|
||||||
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
||||||
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
||||||
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
|
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
|
||||||
adv_criterion.to(args.device)
|
adv_criterion.to(device)
|
||||||
|
|
||||||
if args.adv_lr is None:
|
|
||||||
args.adv_lr = args.lr
|
|
||||||
if args.adv_weight_decay is None:
|
|
||||||
args.adv_weight_decay = args.weight_decay
|
|
||||||
|
|
||||||
adv_optimizer = getattr(torch.optim, args.optimizer)
|
adv_optimizer = getattr(torch.optim, args.optimizer)
|
||||||
adv_optimizer = adv_optimizer(
|
adv_optimizer = adv_optimizer(
|
||||||
@ -159,21 +168,31 @@ def gpu_worker(local_rank, args):
|
|||||||
factor=0.1, patience=10, verbose=True)
|
factor=0.1, patience=10, verbose=True)
|
||||||
|
|
||||||
if args.load_state:
|
if args.load_state:
|
||||||
state = torch.load(args.load_state, map_location=args.device)
|
state = torch.load(args.load_state, map_location=device)
|
||||||
args.start_epoch = state['epoch']
|
|
||||||
args.adv_delay += args.start_epoch
|
start_epoch = state['epoch']
|
||||||
|
|
||||||
load_model_state_dict(model.module, state['model'],
|
load_model_state_dict(model.module, state['model'],
|
||||||
strict=args.load_state_strict)
|
strict=args.load_state_strict)
|
||||||
if 'adv_model' in state and args.adv:
|
|
||||||
load_model_state_dict(adv_model.module, state['adv_model'],
|
if args.adv:
|
||||||
strict=args.load_state_strict)
|
if 'adv_model' in state:
|
||||||
|
args.adv_epoch = state['adv_epoch']
|
||||||
|
|
||||||
|
load_model_state_dict(adv_model.module, state['adv_model'],
|
||||||
|
strict=args.load_state_strict)
|
||||||
|
else:
|
||||||
|
args.adv_epoch = start_epoch
|
||||||
|
|
||||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||||
if args.rank == 0:
|
|
||||||
|
if rank == 0:
|
||||||
min_loss = state['min_loss']
|
min_loss = state['min_loss']
|
||||||
if 'adv_model' not in state and args.adv:
|
if 'adv_model' not in state and args.adv:
|
||||||
min_loss = None # restarting with adversary wipes the record
|
min_loss = None # restarting with adversary wipes the record
|
||||||
print('checkpoint at epoch {} loaded from {}'.format(
|
print('checkpoint at epoch {} loaded from {}'.format(
|
||||||
state['epoch'], args.load_state))
|
state['epoch'], args.load_state))
|
||||||
|
|
||||||
del state
|
del state
|
||||||
else:
|
else:
|
||||||
# def init_weights(m):
|
# def init_weights(m):
|
||||||
@ -185,44 +204,40 @@ def gpu_worker(local_rank, args):
|
|||||||
# m.bias.data.fill_(0)
|
# m.bias.data.fill_(0)
|
||||||
# model.apply(init_weights)
|
# model.apply(init_weights)
|
||||||
#
|
#
|
||||||
args.start_epoch = 0
|
start_epoch = 0
|
||||||
if args.rank == 0:
|
|
||||||
|
if rank == 0:
|
||||||
min_loss = None
|
min_loss = None
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True # NOTE: test perf
|
torch.backends.cudnn.benchmark = True # NOTE: test perf
|
||||||
|
|
||||||
if args.rank == 0:
|
logger = None
|
||||||
args.logger = SummaryWriter()
|
if rank == 0:
|
||||||
|
logger = SummaryWriter()
|
||||||
|
|
||||||
for epoch in range(args.start_epoch, args.epochs):
|
for epoch in range(start_epoch, args.epochs):
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
train_loss = train(epoch, train_loader,
|
train_loss = train(epoch, train_loader,
|
||||||
model, criterion, optimizer, scheduler,
|
model, criterion, optimizer, scheduler,
|
||||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||||
args)
|
logger, device, args)
|
||||||
epoch_loss = train_loss
|
epoch_loss = train_loss
|
||||||
|
|
||||||
if args.val:
|
if args.val:
|
||||||
val_loss = validate(epoch, val_loader,
|
val_loss = validate(epoch, val_loader,
|
||||||
model, criterion, adv_model, adv_criterion,
|
model, criterion, adv_model, adv_criterion,
|
||||||
args)
|
logger, device, args)
|
||||||
epoch_loss = val_loss
|
epoch_loss = val_loss
|
||||||
|
|
||||||
if args.reduce_lr_on_plateau:
|
if args.reduce_lr_on_plateau:
|
||||||
if epoch >= args.adv_delay:
|
scheduler.step(epoch_loss[0])
|
||||||
scheduler.step(epoch_loss[0])
|
if args.adv:
|
||||||
if args.adv:
|
adv_scheduler.step(epoch_loss[0])
|
||||||
adv_scheduler.step(epoch_loss[0])
|
|
||||||
else:
|
|
||||||
scheduler.last_epoch = epoch
|
|
||||||
if args.adv:
|
|
||||||
adv_scheduler.last_epoch = epoch
|
|
||||||
|
|
||||||
if args.rank == 0:
|
if rank == 0:
|
||||||
print(end='', flush=True)
|
logger.close()
|
||||||
args.logger.close()
|
|
||||||
|
|
||||||
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||||
if good and epoch >= args.adv_delay:
|
if good and epoch >= args.adv_delay:
|
||||||
@ -236,6 +251,7 @@ def gpu_worker(local_rank, args):
|
|||||||
}
|
}
|
||||||
if args.adv:
|
if args.adv:
|
||||||
state.update({
|
state.update({
|
||||||
|
'adv_epoch': args.adv_epoch,
|
||||||
'adv_model': adv_model.module.state_dict(),
|
'adv_model': adv_model.module.state_dict(),
|
||||||
})
|
})
|
||||||
ckpt_file = 'checkpoint.pth'
|
ckpt_file = 'checkpoint.pth'
|
||||||
@ -252,22 +268,26 @@ def gpu_worker(local_rank, args):
|
|||||||
|
|
||||||
|
|
||||||
def train(epoch, loader, model, criterion, optimizer, scheduler,
|
def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler, args):
|
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||||
|
logger, device, args):
|
||||||
model.train()
|
model.train()
|
||||||
if args.adv:
|
if args.adv:
|
||||||
adv_model.train()
|
adv_model.train()
|
||||||
|
|
||||||
|
rank = dist.get_rank()
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
||||||
# loss: generator (model) supervised loss
|
# loss: generator (model) supervised loss
|
||||||
# loss_adv: generator (model) adversarial loss
|
# loss_adv: generator (model) adversarial loss
|
||||||
# adv_loss: discriminator (adv_model) loss
|
# adv_loss: discriminator (adv_model) loss
|
||||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
||||||
real = torch.ones(1, dtype=torch.float32, device=args.device)
|
real = torch.ones(1, dtype=torch.float32, device=device)
|
||||||
fake = torch.zeros(1, dtype=torch.float32, device=args.device)
|
fake = torch.zeros(1, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for i, (input, target) in enumerate(loader):
|
for i, (input, target) in enumerate(loader):
|
||||||
input = input.to(args.device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
target = target.to(args.device, non_blocking=True)
|
target = target.to(device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if args.noise_chan > 0:
|
if args.noise_chan > 0:
|
||||||
@ -291,9 +311,11 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
loss_adv, = adv_criterion(eval_out, real)
|
loss_adv, = adv_criterion(eval_out, real)
|
||||||
epoch_loss[1] += loss_adv.item()
|
epoch_loss[1] += loss_adv.item()
|
||||||
|
|
||||||
if epoch >= args.adv_delay:
|
r = loss.item() / (loss_adv.item() + 1e-8)
|
||||||
loss_fac = loss.item() / (loss_adv.item() + 1e-8)
|
f = args.loss_fraction
|
||||||
loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?
|
e = epoch - args.adv_epoch
|
||||||
|
d = 0.5 ** (e / args.loss_halflife)
|
||||||
|
loss = (f + (1 - f) * d) * loss + (1 - f) * (1 - d) * r * loss_adv
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
@ -315,37 +337,37 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
batch = epoch * len(loader) + i + 1
|
batch = epoch * len(loader) + i + 1
|
||||||
if batch % args.log_interval == 0:
|
if batch % args.log_interval == 0:
|
||||||
dist.all_reduce(loss)
|
dist.all_reduce(loss)
|
||||||
loss /= args.world_size
|
loss /= world_size
|
||||||
if args.rank == 0:
|
if rank == 0:
|
||||||
args.logger.add_scalar('loss/batch/train', loss.item(),
|
logger.add_scalar('loss/batch/train', loss.item(),
|
||||||
global_step=batch)
|
global_step=batch)
|
||||||
if args.adv:
|
if args.adv:
|
||||||
args.logger.add_scalar('loss/batch/train/adv/G',
|
logger.add_scalar('loss/batch/train/adv/G',
|
||||||
loss_adv.item(), global_step=batch)
|
loss_adv.item(), global_step=batch)
|
||||||
args.logger.add_scalars('loss/batch/train/adv/D', {
|
logger.add_scalars('loss/batch/train/adv/D', {
|
||||||
'total': adv_loss.item(),
|
'total': adv_loss.item(),
|
||||||
'fake': adv_loss_fake.item(),
|
'fake': adv_loss_fake.item(),
|
||||||
'real': adv_loss_real.item(),
|
'real': adv_loss_real.item(),
|
||||||
}, global_step=batch)
|
}, global_step=batch)
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * args.world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if args.rank == 0:
|
if rank == 0:
|
||||||
args.logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
if args.adv:
|
if args.adv:
|
||||||
args.logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
args.logger.add_scalars('loss/epoch/train/adv/D', {
|
logger.add_scalars('loss/epoch/train/adv/D', {
|
||||||
'total': epoch_loss[2],
|
'total': epoch_loss[2],
|
||||||
'fake': epoch_loss[3],
|
'fake': epoch_loss[3],
|
||||||
'real': epoch_loss[4],
|
'real': epoch_loss[4],
|
||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||||
args.logger.add_figure('fig/epoch/train/in',
|
logger.add_figure('fig/epoch/train/in',
|
||||||
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
||||||
args.logger.add_figure('fig/epoch/train/out',
|
logger.add_figure('fig/epoch/train/out',
|
||||||
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
||||||
global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
@ -353,19 +375,23 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
return epoch_loss
|
return epoch_loss
|
||||||
|
|
||||||
|
|
||||||
def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||||
|
logger, device, args):
|
||||||
model.eval()
|
model.eval()
|
||||||
if args.adv:
|
if args.adv:
|
||||||
adv_model.eval()
|
adv_model.eval()
|
||||||
|
|
||||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
rank = dist.get_rank()
|
||||||
fake = torch.zeros(1, dtype=torch.float32, device=args.device)
|
world_size = dist.get_world_size()
|
||||||
real = torch.ones(1, dtype=torch.float32, device=args.device)
|
|
||||||
|
epoch_loss = torch.zeros(5, dtype=torch.float64, device=device)
|
||||||
|
fake = torch.zeros(1, dtype=torch.float32, device=device)
|
||||||
|
real = torch.ones(1, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for input, target in loader:
|
for input, target in loader:
|
||||||
input = input.to(args.device, non_blocking=True)
|
input = input.to(device, non_blocking=True)
|
||||||
target = target.to(args.device, non_blocking=True)
|
target = target.to(device, non_blocking=True)
|
||||||
|
|
||||||
output = model(input)
|
output = model(input)
|
||||||
if args.noise_chan > 0:
|
if args.noise_chan > 0:
|
||||||
@ -398,23 +424,23 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||||||
epoch_loss[1] += loss_adv.item()
|
epoch_loss[1] += loss_adv.item()
|
||||||
|
|
||||||
dist.all_reduce(epoch_loss)
|
dist.all_reduce(epoch_loss)
|
||||||
epoch_loss /= len(loader) * args.world_size
|
epoch_loss /= len(loader) * world_size
|
||||||
if args.rank == 0:
|
if rank == 0:
|
||||||
args.logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
if args.adv:
|
if args.adv:
|
||||||
args.logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
||||||
global_step=epoch+1)
|
global_step=epoch+1)
|
||||||
args.logger.add_scalars('loss/epoch/val/adv/D', {
|
logger.add_scalars('loss/epoch/val/adv/D', {
|
||||||
'total': epoch_loss[2],
|
'total': epoch_loss[2],
|
||||||
'fake': epoch_loss[3],
|
'fake': epoch_loss[3],
|
||||||
'real': epoch_loss[4],
|
'real': epoch_loss[4],
|
||||||
}, global_step=epoch+1)
|
}, global_step=epoch+1)
|
||||||
|
|
||||||
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
skip_chan = sum(args.in_chan) if args.adv and args.cgan else 0
|
||||||
args.logger.add_figure('fig/epoch/val/in',
|
logger.add_figure('fig/epoch/val/in',
|
||||||
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
fig3d(narrow_like(input, output)[-1]), global_step =epoch+1)
|
||||||
args.logger.add_figure('fig/epoch/val',
|
logger.add_figure('fig/epoch/val',
|
||||||
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
fig3d(output[-1, skip_chan:], target[-1, skip_chan:],
|
||||||
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
output[-1, skip_chan:] - target[-1, skip_chan:]),
|
||||||
global_step =epoch+1)
|
global_step =epoch+1)
|
||||||
|
Loading…
Reference in New Issue
Block a user