Remove min-reduction; Split grad plots; Disable cgan w/o adv
This commit is contained in:
parent
29ab550032
commit
d0ab596e11
@ -40,12 +40,13 @@ def add_common_args(parser):
|
|||||||
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
parser.add_argument('--tgt-norms', type=str_list, help='comma-sep. list '
|
||||||
'of target normalization functions from .data.norms')
|
'of target normalization functions from .data.norms')
|
||||||
parser.add_argument('--crop', type=int,
|
parser.add_argument('--crop', type=int,
|
||||||
help='size to crop the input and target data')
|
help='size to crop the input and target data. Default is the '
|
||||||
|
'field size')
|
||||||
parser.add_argument('--crop-start', type=int,
|
parser.add_argument('--crop-start', type=int,
|
||||||
help='starting point of the first crop. Default is the origin')
|
help='starting point of the first crop. Default is the origin')
|
||||||
parser.add_argument('--crop-stop', type=int,
|
parser.add_argument('--crop-stop', type=int,
|
||||||
help='stopping point of the last crop. Default is the corner '
|
help='stopping point of the last crop. Default is the opposite '
|
||||||
'opposite to the origin')
|
'corner to the origin')
|
||||||
parser.add_argument('--crop-step', type=int,
|
parser.add_argument('--crop-step', type=int,
|
||||||
help='spacing between crops. Default is the crop size')
|
help='spacing between crops. Default is the crop size')
|
||||||
parser.add_argument('--pad', default=0, type=int,
|
parser.add_argument('--pad', default=0, type=int,
|
||||||
@ -67,12 +68,12 @@ def add_common_args(parser):
|
|||||||
help='allow incompatible keys when loading model states',
|
help='allow incompatible keys when loading model states',
|
||||||
dest='load_state_strict')
|
dest='load_state_strict')
|
||||||
|
|
||||||
parser.add_argument('--batches', default=1, type=int,
|
parser.add_argument('--batches', type=int, required=True,
|
||||||
help='mini-batch size, per GPU in training or in total in testing')
|
help='mini-batch size, per GPU in training or in total in testing')
|
||||||
parser.add_argument('--loader-workers', type=int,
|
parser.add_argument('--loader-workers', type=int,
|
||||||
help='number of data loading workers, per GPU in training or '
|
help='number of data loading workers, per GPU in training or '
|
||||||
'in total in testing. '
|
'in total in testing. Default is 0 for single batch, '
|
||||||
'Default is the batch size or 0 for batch size 1')
|
'otherwise same as the batch size')
|
||||||
|
|
||||||
parser.add_argument('--cache', action='store_true',
|
parser.add_argument('--cache', action='store_true',
|
||||||
help='enable LRU cache of input and target fields to reduce I/O')
|
help='enable LRU cache of input and target fields to reduce I/O')
|
||||||
@ -117,8 +118,6 @@ def add_train_args(parser):
|
|||||||
help='enable spectral normalization on the adversary model')
|
help='enable spectral normalization on the adversary model')
|
||||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||||
help='adversarial criterion from torch.nn')
|
help='adversarial criterion from torch.nn')
|
||||||
parser.add_argument('--min-reduction', action='store_true',
|
|
||||||
help='enable minimum reduction in adversarial criterion')
|
|
||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
parser.add_argument('--adv-start', default=0, type=int,
|
parser.add_argument('--adv-start', default=0, type=int,
|
||||||
@ -217,6 +216,11 @@ def set_train_args(args):
|
|||||||
if args.adv_weight_decay is None:
|
if args.adv_weight_decay is None:
|
||||||
args.adv_weight_decay = args.weight_decay
|
args.adv_weight_decay = args.weight_decay
|
||||||
|
|
||||||
|
if args.cgan and not args.adv:
|
||||||
|
args.cgan =False
|
||||||
|
warnings.warn('Disabling cgan given adversary is disabled',
|
||||||
|
RuntimeWarning)
|
||||||
|
|
||||||
|
|
||||||
def set_test_args(args):
|
def set_test_args(args):
|
||||||
set_common_args(args)
|
set_common_args(args)
|
||||||
|
@ -19,7 +19,6 @@ def adv_criterion_wrapper(module):
|
|||||||
"""Wrap an adversarial criterion to:
|
"""Wrap an adversarial criterion to:
|
||||||
* also take lists of Tensors as target, used to split the input Tensor
|
* also take lists of Tensors as target, used to split the input Tensor
|
||||||
along the batch dimension
|
along the batch dimension
|
||||||
* enable min reduction on input
|
|
||||||
* expand target shape as that of input
|
* expand target shape as that of input
|
||||||
* return a list of losses, one for each pair of input and target Tensors
|
* return a list of losses, one for each pair of input and target Tensors
|
||||||
"""
|
"""
|
||||||
@ -34,17 +33,8 @@ def adv_criterion_wrapper(module):
|
|||||||
input = self.split_input(input, target)
|
input = self.split_input(input, target)
|
||||||
assert len(input) == len(target)
|
assert len(input) == len(target)
|
||||||
|
|
||||||
if self.reduction == 'min':
|
|
||||||
input = [torch.min(i).unsqueeze(0) for i in input]
|
|
||||||
|
|
||||||
target = [t.expand_as(i) for i, t in zip(input, target)]
|
target = [t.expand_as(i) for i, t in zip(input, target)]
|
||||||
|
|
||||||
if self.reduction == 'min':
|
|
||||||
self.reduction = 'mean' # average over batches
|
|
||||||
loss = [super(new_module, self).forward(i, t)
|
|
||||||
for i, t in zip(input, target)]
|
|
||||||
self.reduction = 'min'
|
|
||||||
else:
|
|
||||||
loss = [super(new_module, self).forward(i, t)
|
loss = [super(new_module, self).forward(i, t)
|
||||||
for i, t in zip(input, target)]
|
for i, t in zip(input, target)]
|
||||||
|
|
||||||
|
@ -176,7 +176,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
|
|
||||||
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
|
adv_criterion = import_attr(args.adv_criterion, nn.__name__, args.callback_at)
|
||||||
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
||||||
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
|
adv_criterion = adv_criterion()
|
||||||
adv_criterion.to(device)
|
adv_criterion.to(device)
|
||||||
|
|
||||||
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
adv_optimizer = import_attr(args.optimizer, optim.__name__, args.callback_at)
|
||||||
@ -410,19 +410,17 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
if '.weight' in n)
|
if '.weight' in n)
|
||||||
grads = [grads[0], grads[-1]]
|
grads = [grads[0], grads[-1]]
|
||||||
grads = [g.detach().norm().item() for g in grads]
|
grads = [g.detach().norm().item() for g in grads]
|
||||||
logger.add_scalars('grad', {
|
logger.add_scalar('grad/first', grads[0], global_step=batch)
|
||||||
'first': grads[0],
|
logger.add_scalar('grad/last', grads[-1], global_step=batch)
|
||||||
'last': grads[-1],
|
|
||||||
}, global_step=batch)
|
|
||||||
if args.adv and epoch >= args.adv_start:
|
if args.adv and epoch >= args.adv_start:
|
||||||
grads = list(p.grad for n, p in adv_model.named_parameters()
|
grads = list(p.grad for n, p in adv_model.named_parameters()
|
||||||
if '.weight' in n)
|
if '.weight' in n)
|
||||||
grads = [grads[0], grads[-1]]
|
grads = [grads[0], grads[-1]]
|
||||||
grads = [g.detach().norm().item() for g in grads]
|
grads = [g.detach().norm().item() for g in grads]
|
||||||
logger.add_scalars('grad/adv', {
|
logger.add_scalar('grad/adv/first', grads[0],
|
||||||
'first': grads[0],
|
global_step=batch)
|
||||||
'last': grads[-1],
|
logger.add_scalar('grad/adv/last', grads[-1],
|
||||||
}, global_step=batch)
|
global_step=batch)
|
||||||
|
|
||||||
if args.adv and epoch >= args.adv_start and noise_std > 0:
|
if args.adv and epoch >= args.adv_start and noise_std > 0:
|
||||||
logger.add_scalar('instance_noise', noise_std,
|
logger.add_scalar('instance_noise', noise_std,
|
||||||
|
Loading…
Reference in New Issue
Block a user