Add weight initialization

This commit is contained in:
Yin Li 2020-03-01 22:31:30 -05:00
parent 75b1c19dcd
commit 9567db7332
4 changed files with 29 additions and 13 deletions

View File

@ -80,7 +80,7 @@ def add_train_args(parser):
help='epoch to start adversarial training') help='epoch to start adversarial training')
parser.add_argument('--adv-label-smoothing', default=1, type=float, parser.add_argument('--adv-label-smoothing', default=1, type=float,
help='label of real samples for discriminator, ' help='label of real samples for discriminator, '
'e.g. 0.9 for label smoothing') 'e.g. 0.9 for label smoothing and 1 to disable')
parser.add_argument('--loss-fraction', default=0.5, type=float, parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)') help='final fraction of loss (vs adv-loss)')
parser.add_argument('--loss-halflife', default=20, type=float, parser.add_argument('--loss-halflife', default=20, type=float,
@ -100,6 +100,9 @@ def add_train_args(parser):
help='adversary weight decay') help='adversary weight decay')
parser.add_argument('--reduce-lr-on-plateau', action='store_true', parser.add_argument('--reduce-lr-on-plateau', action='store_true',
help='Enable ReduceLROnPlateau learning rate scheduler') help='Enable ReduceLROnPlateau learning rate scheduler')
parser.add_argument('--init-weight-scale', type=float,
help='weight initialization scale, default is 0.02 with adversary '
'and the pytorch default without it')
parser.add_argument('--epochs', default=128, type=int, parser.add_argument('--epochs', default=128, type=int,
help='total number of epochs to run') help='total number of epochs to run')
parser.add_argument('--seed', default=42, type=int, parser.add_argument('--seed', default=42, type=int,

View File

@ -13,7 +13,8 @@ def add_spectral_norm(module):
def rm_spectral_norm(module): def rm_spectral_norm(module):
for name, child in module.named_children(): for name, child in module.named_children():
if isinstance(child, (nn._ConvNd, nn.Linear)): if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
setattr(module, name, remove_spectral_norm(child)) setattr(module, name, remove_spectral_norm(child))
else: else:
rm_spectral_norm(child) rm_spectral_norm(child)

View File

@ -11,6 +11,7 @@ from .state import load_model_state_dict
def test(args): def test(args):
pprint(vars(args)) pprint(vars(args))
sys.stdout.flush()
test_dataset = FieldDataset( test_dataset = FieldDataset(
in_patterns=args.test_in_patterns, in_patterns=args.test_in_patterns,

View File

@ -5,6 +5,7 @@ import time
import sys import sys
from pprint import pprint from pprint import pprint
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.distributed as dist import torch.distributed as dist
from torch.multiprocessing import spawn from torch.multiprocessing import spawn
@ -113,7 +114,7 @@ def gpu_worker(local_rank, node, args):
model = DistributedDataParallel(model, device_ids=[device], model = DistributedDataParallel(model, device_ids=[device],
process_group=dist.new_group()) process_group=dist.new_group())
criterion = getattr(torch.nn, args.criterion) criterion = getattr(nn, args.criterion)
criterion = criterion() criterion = criterion()
criterion.to(device) criterion.to(device)
@ -140,7 +141,7 @@ def gpu_worker(local_rank, node, args):
adv_model = DistributedDataParallel(adv_model, device_ids=[device], adv_model = DistributedDataParallel(adv_model, device_ids=[device],
process_group=dist.new_group()) process_group=dist.new_group())
adv_criterion = getattr(torch.nn, args.adv_criterion) adv_criterion = getattr(nn, args.adv_criterion)
adv_criterion = adv_criterion_wrapper(adv_criterion) adv_criterion = adv_criterion_wrapper(adv_criterion)
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean') adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
adv_criterion.to(device) adv_criterion.to(device)
@ -178,15 +179,22 @@ def gpu_worker(local_rank, node, args):
del state del state
else: else:
# def init_weights(m): def init_weights(m):
# classname = m.__class__.__name__ if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
# if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)): nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
# m.weight.data.normal_(0.0, 0.02) m.weight.data.normal_(0.0, args.init_weight_scale)
# elif isinstance(m, torch.nn.BatchNorm3d): elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
# m.weight.data.normal_(1.0, 0.02) nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
# m.bias.data.fill_(0) nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
# model.apply(init_weights) if m.affine:
# # NOTE: dispersion from DCGAN, why?
m.weight.data.normal_(1.0, args.init_weight_scale)
m.bias.data.fill_(0)
if args.init_weight_scale is not None:
model.apply(init_weights)
if args.adv:
adv_model.apply(init_weights)
start_epoch = 0 start_epoch = 0
if rank == 0: if rank == 0:
@ -455,6 +463,9 @@ def set_runtime_default_args(args):
if args.adv_weight_decay is None: if args.adv_weight_decay is None:
args.adv_weight_decay = args.weight_decay args.adv_weight_decay = args.weight_decay
if args.init_weight_scale is None:
args.init_weight_scale = 0.02
def dist_init(rank, args): def dist_init(rank, args):
dist_file = 'dist_addr' dist_file = 'dist_addr'