Remove min-reduction; split grad plots; disable cgan without adversary
This commit is contained in:
parent
29ab550032
commit
d0ab596e11
3 changed files with 21 additions and 29 deletions
|
@ -40,12 +40,13 @@ def add_common_args(parser):
|
|||
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
||||
'of target normalization functions from .data.norms')
|
||||
parser.add_argument('--crop', type=int,
|
||||
help='size to crop the input and target data')
|
||||
help='size to crop the input and target data. Default is the '
|
||||
'field size')
|
||||
parser.add_argument('--crop-start', type=int,
|
||||
help='starting point of the first crop. Default is the origin')
|
||||
parser.add_argument('--crop-stop', type=int,
|
||||
help='stopping point of the last crop. Default is the corner '
|
||||
'opposite to the origin')
|
||||
help='stopping point of the last crop. Default is the opposite '
|
||||
'corner to the origin')
|
||||
parser.add_argument('--crop-step', type=int,
|
||||
help='spacing between crops. Default is the crop size')
|
||||
parser.add_argument('--pad', default=0, type=int,
|
||||
|
@ -67,12 +68,12 @@ def add_common_args(parser):
|
|||
help='allow incompatible keys when loading model states',
|
||||
dest='load_state_strict')
|
||||
|
||||
parser.add_argument('--batches', default=1, type=int,
|
||||
parser.add_argument('--batches', type=int, required=True,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
parser.add_argument('--loader-workers', type=int,
|
||||
help='number of data loading workers, per GPU in training or '
|
||||
'in total in testing. '
|
||||
'Default is the batch size or 0 for batch size 1')
|
||||
'in total in testing. Default is 0 for single batch, '
|
||||
'otherwise same as the batch size')
|
||||
|
||||
parser.add_argument('--cache', action='store_true',
|
||||
help='enable LRU cache of input and target fields to reduce I/O')
|
||||
|
@ -117,8 +118,6 @@ def add_train_args(parser):
|
|||
help='enable spectral normalization on the adversary model')
|
||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||
help='adversarial criterion from torch.nn')
|
||||
parser.add_argument('--min-reduction', action='store_true',
|
||||
help='enable minimum reduction in adversarial criterion')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--adv-start', default=0, type=int,
|
||||
|
@ -217,6 +216,11 @@ def set_train_args(args):
|
|||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
if args.cgan and not args.adv:
|
||||
args.cgan =False
|
||||
warnings.warn('Disabling cgan given adversary is disabled',
|
||||
RuntimeWarning)
|
||||
|
||||
|
||||
def set_test_args(args):
|
||||
set_common_args(args)
|
||||
|
|
|
@ -19,7 +19,6 @@ def adv_criterion_wrapper(module):
|
|||
"""Wrap an adversarial criterion to:
|
||||
* also take lists of Tensors as target, used to split the input Tensor
|
||||
along the batch dimension
|
||||
* enable min reduction on input
|
||||
* expand target shape as that of input
|
||||
* return a list of losses, one for each pair of input and target Tensors
|
||||
"""
|
||||
|
@ -34,19 +33,10 @@ def adv_criterion_wrapper(module):
|
|||
input = self.split_input(input, target)
|
||||
assert len(input) == len(target)
|
||||
|
||||
if self.reduction == 'min':
|
||||
input = [torch.min(i).unsqueeze(0) for i in input]
|
||||
|
||||
target = [t.expand_as(i) for i, t in zip(input, target)]
|
||||
|
||||
if self.reduction == 'min':
|
||||
self.reduction = 'mean' # average over batches
|
||||
loss = [super(new_module, self).forward(i, t)
|
||||
for i, t in zip(input, target)]
|
||||
self.reduction = 'min'
|
||||
else:
|
||||
loss = [super(new_module, self).forward(i, t)
|
||||
for i, t in zip(input, target)]
|
||||
loss = [super(new_module, self).forward(i, t)
|
||||
for i, t in zip(input, target)]
|
||||
|
||||
return loss
|
||||
|
||||
|
|
|
@ -176,7 +176,7 @@ def gpu_worker(local_rank, node, args):
|
|||
|
||||
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
|
||||
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
||||
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
|
||||
adv_criterion = adv_criterion()
|
||||
adv_criterion.to(device)
|
||||
|
||||
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
||||
|
@ -410,19 +410,17 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
if '.weight' in n)
|
||||
grads = [grads[0], grads[-1]]
|
||||
grads = [g.detach().norm().item() for g in grads]
|
||||
logger.add_scalars('grad', {
|
||||
'first': grads[0],
|
||||
'last': grads[-1],
|
||||
}, global_step=batch)
|
||||
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
||||
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
||||
if args.adv and epoch >= args.adv_start:
|
||||
grads = list(p.grad for n, p in adv_model.named_parameters()
|
||||
if '.weight' in n)
|
||||
grads = [grads[0], grads[-1]]
|
||||
grads = [g.detach().norm().item() for g in grads]
|
||||
logger.add_scalars('grad/adv', {
|
||||
'first': grads[0],
|
||||
'last': grads[-1],
|
||||
}, global_step=batch)
|
||||
logger.add_scalars('grad/adv/first', grads[0],
|
||||
global_step=batch)
|
||||
logger.add_scalars('grad/adv/last', grads[-1],
|
||||
global_step=batch)
|
||||
|
||||
if args.adv and epoch >= args.adv_start and noise_std > 0:
|
||||
logger.add_scalar('instance_noise', noise_std,
|
||||
|
|
Loading…
Reference in a new issue