Remove min-reduction; Split grad plots; Disable cgan w/o adv

Yin Li 2020-07-13 09:40:15 -04:00
parent 29ab550032
commit d0ab596e11
3 changed files with 21 additions and 29 deletions

View File

@@ -40,12 +40,13 @@ def add_common_args(parser):
     parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
             'of target normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
-            help='size to crop the input and target data')
+            help='size to crop the input and target data. Default is the '
+            'field size')
     parser.add_argument('--crop-start', type=int,
             help='starting point of the first crop. Default is the origin')
     parser.add_argument('--crop-stop', type=int,
-            help='stopping point of the last crop. Default is the corner '
-            'opposite to the origin')
+            help='stopping point of the last crop. Default is the opposite '
+            'corner to the origin')
     parser.add_argument('--crop-step', type=int,
             help='spacing between crops. Default is the crop size')
     parser.add_argument('--pad', default=0, type=int,
@@ -67,12 +68,12 @@ def add_common_args(parser):
             help='allow incompatible keys when loading model states',
             dest='load_state_strict')
-    parser.add_argument('--batches', default=1, type=int,
+    parser.add_argument('--batches', type=int, required=True,
             help='mini-batch size, per GPU in training or in total in testing')
     parser.add_argument('--loader-workers', type=int,
             help='number of data loading workers, per GPU in training or '
-            'in total in testing. '
-            'Default is the batch size or 0 for batch size 1')
+            'in total in testing. Default is 0 for single batch, '
+            'otherwise same as the batch size')
     parser.add_argument('--cache', action='store_true',
             help='enable LRU cache of input and target fields to reduce I/O')
@@ -117,8 +118,6 @@ def add_train_args(parser):
             help='enable spectral normalization on the adversary model')
     parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
             help='adversarial criterion from torch.nn')
-    parser.add_argument('--min-reduction', action='store_true',
-            help='enable minimum reduction in adversarial criterion')
     parser.add_argument('--cgan', action='store_true',
             help='enable conditional GAN')
     parser.add_argument('--adv-start', default=0, type=int,
@@ -217,6 +216,11 @@ def set_train_args(args):
     if args.adv_weight_decay is None:
         args.adv_weight_decay = args.weight_decay
 
+    if args.cgan and not args.adv:
+        args.cgan = False
+        warnings.warn('Disabling cgan given adversary is disabled',
+                      RuntimeWarning)
+
 
 def set_test_args(args):
     set_common_args(args)
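For context, the new guard in set_train_args downgrades an inconsistent flag combination instead of raising: a conditional GAN is meaningless without an adversary. A minimal standalone sketch of the same behavior (the two-flag parser here is illustrative, not the repository's full argument set):

    import argparse
    import warnings

    parser = argparse.ArgumentParser()
    parser.add_argument('--adv', action='store_true')
    parser.add_argument('--cgan', action='store_true')
    args = parser.parse_args(['--cgan'])  # cgan requested, adversary off

    # Same guard as in the hunk above: fall back and warn rather than fail
    if args.cgan and not args.adv:
        args.cgan = False
        warnings.warn('Disabling cgan given adversary is disabled',
                      RuntimeWarning)

    assert not args.cgan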

View File

@@ -19,7 +19,6 @@ def adv_criterion_wrapper(module):
     """Wrap an adversarial criterion to:
     * also take lists of Tensors as target, used to split the input Tensor
       along the batch dimension
-    * enable min reduction on input
     * expand target shape as that of input
     * return a list of losses, one for each pair of input and target Tensors
     """
@@ -34,17 +33,8 @@ def adv_criterion_wrapper(module):
             input = self.split_input(input, target)
             assert len(input) == len(target)
 
-            if self.reduction == 'min':
-                input = [torch.min(i).unsqueeze(0) for i in input]
-
             target = [t.expand_as(i) for i, t in zip(input, target)]
 
-            if self.reduction == 'min':
-                self.reduction = 'mean'  # average over batches
-                loss = [super(new_module, self).forward(i, t)
-                        for i, t in zip(input, target)]
-                self.reduction = 'min'
-            else:
-                loss = [super(new_module, self).forward(i, t)
-                        for i, t in zip(input, target)]
+            loss = [super(new_module, self).forward(i, t)
+                    for i, t in zip(input, target)]
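With the 'min' special case gone, the wrapper's forward pass reduces to: expand each (possibly scalar) target to its input's shape, then return one loss per (input, target) pair under the criterion's own reduction. A minimal sketch of that remaining logic with a plain BCEWithLogitsLoss (the helper name paired_losses is hypothetical, not the repository's wrapper):

    import torch
    from torch import nn

    def paired_losses(criterion, inputs, targets):
        # Expand each scalar label to its input's shape, then compute
        # one reduced loss per (input, target) pair
        targets = [t.expand_as(i) for i, t in zip(inputs, targets)]
        return [criterion(i, t) for i, t in zip(inputs, targets)]

    criterion = nn.BCEWithLogitsLoss()  # standard 'mean' reduction
    logits = [torch.randn(4, 1), torch.randn(4, 1)]  # e.g. real and fake scores
    labels = [torch.tensor(1.), torch.tensor(0.)]
    print(paired_losses(criterion, logits, labels))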

View File

@@ -176,7 +176,7 @@ def gpu_worker(local_rank, node, args):
         adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
         adv_criterion = adv_criterion_wrapper(adv_criterion)
-        adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
+        adv_criterion = adv_criterion()
         adv_criterion.to(device)
 
         adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
@@ -410,19 +410,17 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
                          if '.weight' in n)
             grads = [grads[0], grads[-1]]
             grads = [g.detach().norm().item() for g in grads]
-            logger.add_scalars('grad', {
-                'first': grads[0],
-                'last': grads[-1],
-            }, global_step=batch)
+            logger.add_scalar('grad/first', grads[0], global_step=batch)
+            logger.add_scalar('grad/last', grads[-1], global_step=batch)
 
             if args.adv and epoch >= args.adv_start:
                 grads = list(p.grad for n, p in adv_model.named_parameters()
                              if '.weight' in n)
                 grads = [grads[0], grads[-1]]
                 grads = [g.detach().norm().item() for g in grads]
-                logger.add_scalars('grad/adv', {
-                    'first': grads[0],
-                    'last': grads[-1],
-                }, global_step=batch)
+                logger.add_scalar('grad/adv/first', grads[0],
+                                  global_step=batch)
+                logger.add_scalar('grad/adv/last', grads[-1],
+                                  global_step=batch)
 
             if args.adv and epoch >= args.adv_start and noise_std > 0:
                 logger.add_scalar('instance_noise', noise_std,
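The logging change replaces each add_scalars call, which takes a dict of values under one parent tag (and writes a separate event file per key), with individual add_scalar calls, so each gradient norm gets its own curve. Note the adversary branch must use add_scalar as well, since add_scalars does not accept a bare float. A minimal sketch of the split style with TensorBoard's SummaryWriter, using dummy gradient norms and a hypothetical log directory:

    from torch.utils.tensorboard import SummaryWriter

    logger = SummaryWriter('runs/demo')  # hypothetical log dir
    for batch, (first, last) in enumerate([(0.52, 0.11), (0.43, 0.17)]):
        # One tag per curve; TensorBoard groups them under 'grad/'
        logger.add_scalar('grad/first', first, global_step=batch)
        logger.add_scalar('grad/last', last, global_step=batch)
    logger.close()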