Add weight initialization
This commit is contained in:
parent
75b1c19dcd
commit
9567db7332
@ -80,7 +80,7 @@ def add_train_args(parser):
|
||||
help='epoch to start adversarial training')
|
||||
parser.add_argument('--adv-label-smoothing', default=1, type=float,
|
||||
help='label of real samples for discriminator, '
|
||||
'e.g. 0.9 for label smoothing')
|
||||
'e.g. 0.9 for label smoothing and 1 to disable')
|
||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||
help='final fraction of loss (vs adv-loss)')
|
||||
parser.add_argument('--loss-halflife', default=20, type=float,
|
||||
@ -100,6 +100,9 @@ def add_train_args(parser):
|
||||
help='adversary weight decay')
|
||||
parser.add_argument('--reduce-lr-on-plateau', action='store_true',
|
||||
help='Enable ReduceLROnPlateau learning rate scheduler')
|
||||
parser.add_argument('--init-weight-scale', type=float,
|
||||
help='weight initialization scale, default is 0.02 with adversary '
|
||||
'and the pytorch default without it')
|
||||
parser.add_argument('--epochs', default=128, type=int,
|
||||
help='total number of epochs to run')
|
||||
parser.add_argument('--seed', default=42, type=int,
|
||||
|
@ -13,7 +13,8 @@ def add_spectral_norm(module):
|
||||
|
||||
def rm_spectral_norm(module):
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, (nn._ConvNd, nn.Linear)):
|
||||
if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
|
||||
nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
|
||||
setattr(module, name, remove_spectral_norm(child))
|
||||
else:
|
||||
rm_spectral_norm(child)
|
||||
|
@ -11,6 +11,7 @@ from .state import load_model_state_dict
|
||||
|
||||
def test(args):
|
||||
pprint(vars(args))
|
||||
sys.stdout.flush()
|
||||
|
||||
test_dataset = FieldDataset(
|
||||
in_patterns=args.test_in_patterns,
|
||||
|
@ -5,6 +5,7 @@ import time
|
||||
import sys
|
||||
from pprint import pprint
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from torch.multiprocessing import spawn
|
||||
@ -113,7 +114,7 @@ def gpu_worker(local_rank, node, args):
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
||||
criterion = getattr(torch.nn, args.criterion)
|
||||
criterion = getattr(nn, args.criterion)
|
||||
criterion = criterion()
|
||||
criterion.to(device)
|
||||
|
||||
@ -140,7 +141,7 @@ def gpu_worker(local_rank, node, args):
|
||||
adv_model = DistributedDataParallel(adv_model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
|
||||
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
||||
adv_criterion = getattr(nn, args.adv_criterion)
|
||||
adv_criterion = adv_criterion_wrapper(adv_criterion)
|
||||
adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
|
||||
adv_criterion.to(device)
|
||||
@ -178,15 +179,22 @@ def gpu_worker(local_rank, node, args):
|
||||
|
||||
del state
|
||||
else:
|
||||
# def init_weights(m):
|
||||
# classname = m.__class__.__name__
|
||||
# if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
|
||||
# m.weight.data.normal_(0.0, 0.02)
|
||||
# elif isinstance(m, torch.nn.BatchNorm3d):
|
||||
# m.weight.data.normal_(1.0, 0.02)
|
||||
# m.bias.data.fill_(0)
|
||||
# model.apply(init_weights)
|
||||
#
|
||||
def init_weights(m):
    """Re-initialize the parameters of one module in place.

    Written to be handed to ``Module.apply``, which invokes it on every
    submodule. ``args`` is taken from the enclosing scope; only
    ``args.init_weight_scale`` is read and it is used as the standard
    deviation of the normal distributions below.

    * Linear and (transposed) convolution layers: weight ~ N(0, scale).
      Their biases are left untouched.
    * Normalization layers that carry learnable parameters:
      weight ~ N(1, scale), bias = 0.
    * Any other module type is ignored.
    """
    scale = args.init_weight_scale

    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                      nn.ConvTranspose1d, nn.ConvTranspose2d,
                      nn.ConvTranspose3d)):
        m.weight.data.normal_(0.0, scale)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                        nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
                        nn.InstanceNorm1d, nn.InstanceNorm2d,
                        nn.InstanceNorm3d)):
        # Test for the parameters themselves rather than ``m.affine``:
        # LayerNorm names that flag ``elementwise_affine`` and has no
        # ``affine`` attribute, so the attribute lookup would raise. Layers
        # built without learnable parameters hold ``None`` here.
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None:
            # NOTE(review): the non-zero spread around 1 follows the DCGAN
            # reference initialization; the rationale is not documented here.
            weight.data.normal_(1.0, scale)
        if bias is not None:
            bias.data.fill_(0)
|
||||
if args.init_weight_scale is not None:
|
||||
model.apply(init_weights)
|
||||
if args.adv:
|
||||
adv_model.apply(init_weights)
|
||||
|
||||
start_epoch = 0
|
||||
|
||||
if rank == 0:
|
||||
@ -455,6 +463,9 @@ def set_runtime_default_args(args):
|
||||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
if args.init_weight_scale is None:
|
||||
args.init_weight_scale = 0.02
|
||||
|
||||
|
||||
def dist_init(rank, args):
|
||||
dist_file = 'dist_addr'
|
||||
|
Loading…
Reference in New Issue
Block a user