Add weight initialization
commit 9567db7332
parent 75b1c19dcd
@@ -80,7 +80,7 @@ def add_train_args(parser):
             help='epoch to start adversarial training')
     parser.add_argument('--adv-label-smoothing', default=1, type=float,
             help='label of real samples for discriminator, '
-                 'e.g. 0.9 for label smoothing')
+                 'e.g. 0.9 for label smoothing and 1 to disable')
     parser.add_argument('--loss-fraction', default=0.5, type=float,
             help='final fraction of loss (vs adv-loss)')
     parser.add_argument('--loss-halflife', default=20, type=float,
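The smoothed real label feeds the discriminator's loss. A minimal sketch of one-sided label smoothing (hypothetical wiring; this diff does not show the training loop itself):

import torch
import torch.nn as nn

adv_label_smoothing = 0.9                      # 1 disables smoothing
bce = nn.BCEWithLogitsLoss()
d_logits_real = torch.randn(8, 1)              # discriminator output on a real batch
target = torch.full_like(d_logits_real, adv_label_smoothing)
loss_real = bce(d_logits_real, target)         # real label 0.9 instead of 1.0

Keeping the real target slightly below 1 is a common GAN trick to stop the discriminator from becoming overconfident.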
@@ -100,6 +100,9 @@ def add_train_args(parser):
             help='adversary weight decay')
     parser.add_argument('--reduce-lr-on-plateau', action='store_true',
             help='Enable ReduceLROnPlateau learning rate scheduler')
+    parser.add_argument('--init-weight-scale', type=float,
+            help='weight initialization scale, default is 0.02 with adversary '
+                 'and the pytorch default without it')
     parser.add_argument('--epochs', default=128, type=int,
             help='total number of epochs to run')
     parser.add_argument('--seed', default=42, type=int,
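Note that --init-weight-scale deliberately carries no default=: an omitted flag parses to None, which the later hunks use to tell "unset" apart from an explicit value (unset falls back to 0.02 under the adversary, see set_runtime_default_args below). A standalone sketch of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--init-weight-scale', type=float,
        help='weight initialization scale')

args = parser.parse_args([])
assert args.init_weight_scale is None          # unset: resolved later per adversary setting

args = parser.parse_args(['--init-weight-scale', '0.05'])
assert args.init_weight_scale == 0.05          # explicit value wins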
@@ -13,7 +13,8 @@ def add_spectral_norm(module):

 def rm_spectral_norm(module):
     for name, child in module.named_children():
-        if isinstance(child, (nn._ConvNd, nn.Linear)):
+        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
             setattr(module, name, remove_spectral_norm(child))
         else:
             rm_spectral_norm(child)
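The replaced isinstance check was broken: _ConvNd is a private base class living at torch.nn.modules.conv._ConvNd and is not re-exported as nn._ConvNd, so the old line would raise AttributeError the first time it ran. Listing the public Conv/ConvTranspose classes avoids the private API entirely. A quick check:

import torch.nn as nn

print(hasattr(nn, '_ConvNd'))                  # False: the old check could not work
print(issubclass(nn.Conv3d, nn.modules.conv._ConvNd))  # True, but via a private path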
@@ -11,6 +11,7 @@ from .state import load_model_state_dict

 def test(args):
     pprint(vars(args))
+    sys.stdout.flush()

     test_dataset = FieldDataset(
         in_patterns=args.test_in_patterns,
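The added flush matters because spawned workers and batch schedulers typically get block-buffered stdout, so the config dump can lag far behind actual progress in the log file. The minimal pattern (illustrative values, not from this repo):

import sys
from pprint import pprint

pprint({'batches': 8, 'epochs': 128})          # dump the run configuration
sys.stdout.flush()                             # push it to the log right away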
@@ -5,6 +5,7 @@ import time
 import sys
 from pprint import pprint
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.multiprocessing import spawn
@@ -113,7 +114,7 @@ def gpu_worker(local_rank, node, args):
     model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())

-    criterion = getattr(torch.nn, args.criterion)
+    criterion = getattr(nn, args.criterion)
     criterion = criterion()
     criterion.to(device)

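With nn now imported, both criteria resolve their loss class by name through the same getattr pattern. A sketch of that lookup, assuming a built-in loss name such as 'MSELoss':

import torch
import torch.nn as nn

criterion = getattr(nn, 'MSELoss')             # same class object as nn.MSELoss
criterion = criterion()
print(criterion(torch.zeros(4), torch.ones(4)))  # tensor(1.)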
@@ -140,7 +141,7 @@ def gpu_worker(local_rank, node, args):
         adv_model = DistributedDataParallel(adv_model, device_ids=[device],
                 process_group=dist.new_group())

-        adv_criterion = getattr(torch.nn, args.adv_criterion)
+        adv_criterion = getattr(nn, args.adv_criterion)
         adv_criterion = adv_criterion_wrapper(adv_criterion)
         adv_criterion = adv_criterion(reduction='min' if args.min_reduction else 'mean')
         adv_criterion.to(device)
@@ -178,15 +179,22 @@ def gpu_worker(local_rank, node, args):

         del state
     else:
-        #def init_weights(m):
-        #    classname = m.__class__.__name__
-        #    if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
-        #        m.weight.data.normal_(0.0, 0.02)
-        #    elif isinstance(m, torch.nn.BatchNorm3d):
-        #        m.weight.data.normal_(1.0, 0.02)
-        #        m.bias.data.fill_(0)
-        #model.apply(init_weights)
-        #
+        def init_weights(m):
+            if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
+                    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+                m.weight.data.normal_(0.0, args.init_weight_scale)
+            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
+                    nn.SyncBatchNorm, nn.LayerNorm, nn.GroupNorm,
+                    nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
+                if m.affine:
+                    # NOTE: dispersion from DCGAN, why?
+                    m.weight.data.normal_(1.0, args.init_weight_scale)
+                    m.bias.data.fill_(0)
+
+        if args.init_weight_scale is not None:
+            model.apply(init_weights)
+            if args.adv:
+                adv_model.apply(init_weights)

         start_epoch = 0

     if rank == 0:
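A standalone sketch of the initializer above (simplified to 3-d layers; 0.02 is the std popularized by the DCGAN paper, with norm-layer weights centered at 1 rather than 0):

import torch
import torch.nn as nn

def init_weights(m, scale=0.02):
    if isinstance(m, (nn.Linear, nn.Conv3d, nn.ConvTranspose3d)):
        m.weight.data.normal_(0.0, scale)
    elif isinstance(m, nn.BatchNorm3d) and m.affine:
        m.weight.data.normal_(1.0, scale)
        m.bias.data.fill_(0)

model = nn.Sequential(nn.Conv3d(1, 16, 3), nn.BatchNorm3d(16), nn.ReLU())
model.apply(init_weights)                      # .apply() recurses over all submodules
print(model[0].weight.std())                   # ~0.02
print(model[1].weight.mean())                  # ~1.0

Guarding the apply() call with `args.init_weight_scale is not None` means that, without the flag (and without the adversary default below), PyTorch's built-in initialization is left untouched.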
@@ -455,6 +463,9 @@ def set_runtime_default_args(args):
         if args.adv_weight_decay is None:
             args.adv_weight_decay = args.weight_decay

+        if args.init_weight_scale is None:
+            args.init_weight_scale = 0.02
+

 def dist_init(rank, args):
     dist_file = 'dist_addr'