Add optional adversary model and make validation optional

Yin Li 2020-01-09 20:24:46 -05:00
parent 9cf97b3ac1
commit 15384dc9bd
4 changed files with 262 additions and 60 deletions

View File

@@ -17,17 +17,19 @@ def get_args():
 def add_common_args(parser):
     parser.add_argument('--norms', type=str_list, help='comma-sep. list '
-            'of normalization functions from data.norms')
+            'of normalization functions from .data.norms')
     parser.add_argument('--crop', type=int,
             help='size to crop the input and target data')
     parser.add_argument('--pad', default=0, type=int,
-            help='pad the input data assuming periodic boundary condition')
-    parser.add_argument('--model', required=True, help='model from models')
-    parser.add_argument('--criterion', default='MSELoss',
+            help='size to pad the input data beyond the crop size, assuming '
+            'periodic boundary condition')
+    parser.add_argument('--model', required=True, type=str,
+            help='model from .models')
+    parser.add_argument('--criterion', default='MSELoss', type=str,
             help='model criterion from torch.nn')
     parser.add_argument('--load-state', default='', type=str,
-            help='path to load model, optimizer, rng, etc.')
+            help='path to load the states of model, optimizer, rng, etc.')
     parser.add_argument('--batches', default=1, type=int,
             help='mini-batch size, per GPU in training or in total in testing')
@@ -46,16 +48,23 @@ def add_train_args(parser):
             help='comma-sep. list of glob patterns for training input data')
     parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
             help='comma-sep. list of glob patterns for training target data')
-    parser.add_argument('--val-in-patterns', type=str_list, required=True,
+    parser.add_argument('--val-in-patterns', type=str_list,
             help='comma-sep. list of glob patterns for validation input data')
-    parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
+    parser.add_argument('--val-tgt-patterns', type=str_list,
             help='comma-sep. list of glob patterns for validation target data')
     parser.add_argument('--augment', action='store_true',
             help='enable training data augmentation')
+    parser.add_argument('--adv-model', type=str,
+            help='enable adversary model from .models')
+    parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
+            help='adversary criterion from torch.nn')
+    parser.add_argument('--cgan', action='store_true',
+            help='enable conditional GAN')
     parser.add_argument('--epochs', default=128, type=int,
             help='total number of epochs to run')
-    parser.add_argument('--optimizer', default='Adam',
+    parser.add_argument('--optimizer', default='Adam', type=str,
             help='optimizer from torch.optim')
     parser.add_argument('--lr', default=0.001, type=float,
             help='initial learning rate')
@@ -63,6 +72,10 @@ def add_train_args(parser):
     #        help='momentum')
     parser.add_argument('--weight-decay', default=0., type=float,
             help='weight decay')
+    parser.add_argument('--adv-lr', type=float,
+            help='initial adversary learning rate')
+    parser.add_argument('--adv-weight-decay', type=float,
+            help='adversary weight decay')
     parser.add_argument('--seed', default=42, type=int,
             help='seed for initializing training')
@@ -70,7 +83,7 @@ def add_train_args(parser):
             help='enable data division among GPUs, useful with cache')
     parser.add_argument('--dist-backend', default='nccl', type=str,
             choices=['gloo', 'nccl'], help='distributed backend')
-    parser.add_argument('--log-interval', default=20, type=int,
+    parser.add_argument('--log-interval', default=100, type=int,
             help='interval between logging training loss')
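
A self-contained sketch (not the repository's parser) of what the new options amount to: validation runs only when both --val-*-patterns are supplied, the adversary only when --adv-model is given, and the adversary hyperparameters fall back to the main ones; the 'UNet' string here is just a placeholder value.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--val-in-patterns')               # no longer required
    parser.add_argument('--val-tgt-patterns')
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--adv-model')                      # adversary enabled only if given
    parser.add_argument('--adv-lr', type=float)
    parser.add_argument('--cgan', action='store_true')

    args = parser.parse_args(['--adv-model', 'UNet', '--cgan'])
    args.val = args.val_in_patterns is not None and args.val_tgt_patterns is not None
    args.adv = args.adv_model is not None
    if args.adv and args.adv_lr is None:
        args.adv_lr = args.lr                               # falls back to the main lr
    print(args.val, args.adv, args.adv_lr)                  # False True 0.001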

View File

@@ -37,6 +37,8 @@ class FieldDataset(Dataset):
         assert len(self.in_files) == len(self.tgt_files), \
                 'input and target sample sizes do not match'
+        assert len(self.in_files) > 0, 'file not found'
+
         if div_data:
             files = len(self.in_files) // world_size
             self.in_files = self.in_files[rank * files : (rank + 1) * files]
@@ -151,7 +153,7 @@ def flip(fields, axes, ndim):
     new_fields = []
     for x in fields:
-        if x.size(0) == ndim:  # flip vector components
+        if x.shape[0] == ndim:  # flip vector components
             x[axes] = - x[axes]

         axes = (1 + axes).tolist()
@@ -167,7 +169,7 @@ def perm(fields, axes, ndim):
     new_fields = []
     for x in fields:
-        if x.size(0) == ndim:  # permutate vector components
+        if x.shape[0] == ndim:  # permutate vector components
             x = x[axes]

         axes = [0] + (1 + axes).tolist()
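
A toy illustration (not repository code) of why the x.shape[0] == ndim branch exists: flipping a vector field along an axis must also negate that component, whereas a scalar field is only mirrored.

    import torch

    ndim = 3
    disp = torch.randn(ndim, 8, 8, 8)              # channels = x, y, z components
    axes = torch.tensor([1])                       # flip along the y axis

    if disp.shape[0] == ndim:                      # vector field, not a scalar field
        disp[axes] = - disp[axes]                  # negate the flipped component
    disp = torch.flip(disp, (1 + axes).tolist())   # spatial dims start at index 1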

View File

@@ -39,7 +39,9 @@ class ConvBlock(nn.Module):
             in_channels, out_channels = self._setup_conv()
             return nn.Conv3d(in_channels, out_channels, self.kernel_size)
         elif l == 'B':
-            return nn.BatchNorm3d(self.bn_channels)
+            #return nn.BatchNorm3d(self.bn_channels)
+            #return nn.InstanceNorm3d(self.bn_channels, affine=True, track_running_stats=True)
+            return nn.InstanceNorm3d(self.bn_channels)
         elif l == 'A':
             return Swish()
         else:
@@ -109,8 +111,8 @@ def narrow_like(a, b):
     Try to be symmetric but cut more on the right for odd difference,
     consistent with the downsampling.
     """
-    for dim in range(2, 5):
-        width = a.size(dim) - b.size(dim)
+    for d in range(2, a.dim()):
+        width = a.shape[d] - b.shape[d]
         half_width = width // 2
-        a = a.narrow(dim, half_width, a.size(dim) - width)
+        a = a.narrow(d, half_width, a.shape[d] - width)
     return a
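
narrow_like now crops over a.dim() - 2 spatial dimensions instead of a hard-coded three. A small standalone check of the shown logic, with arbitrarily chosen shapes:

    import torch

    def narrow_like(a, b):
        # same logic as the new version above
        for d in range(2, a.dim()):
            width = a.shape[d] - b.shape[d]
            half_width = width // 2
            a = a.narrow(d, half_width, a.shape[d] - width)
        return a

    a = torch.zeros(1, 3, 9, 9, 9)
    b = torch.zeros(1, 1, 6, 6, 6)
    print(narrow_like(a, b).shape)  # torch.Size([1, 3, 6, 6, 6]); 1 cut left, 2 cut right per dim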

View File

@@ -1,8 +1,9 @@
 import os
 import shutil
 import torch
+import torch.nn.functional as F
+import torch.distributed as dist
 from torch.multiprocessing import spawn
-from torch.distributed import init_process_group, destroy_process_group, all_reduce
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data import DataLoader
@@ -35,7 +36,7 @@ def gpu_worker(local_rank, args):
     args.rank = args.gpus_per_node * args.node + local_rank

-    init_process_group(
+    dist.init_process_group(
         backend=args.dist_backend,
         init_method='env://',
         world_size=args.world_size,
@@ -59,11 +60,14 @@ def gpu_worker(local_rank, args):
         pin_memory=True
     )

-    val_dataset = FieldDataset(
-        in_patterns=args.val_in_patterns,
-        tgt_patterns=args.val_tgt_patterns,
-        augment=False,
-        **{k:v for k, v in vars(args).items() if k != 'augment'},
-    )
+    args.val = args.val_in_patterns is not None and \
+            args.val_tgt_patterns is not None
+    if args.val:
+        val_dataset = FieldDataset(
+            in_patterns=args.val_in_patterns,
+            tgt_patterns=args.val_tgt_patterns,
+            augment=False,
+            **{k: v for k, v in vars(args).items() if k != 'augment'},
+        )
     if not args.div_data:
         #val_sampler = DistributedSampler(val_dataset, shuffle=False)
@@ -93,17 +97,50 @@ def gpu_worker(local_rank, args):
         model.parameters(),
         lr=args.lr,
         #momentum=args.momentum,
+        betas=(0.5, 0.999),
         weight_decay=args.weight_decay,
     )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
             factor=0.5, patience=3, verbose=True)

+    adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
+    args.adv = args.adv_model is not None
+    if args.adv:
+        adv_model = getattr(models, args.adv_model)
+        adv_model = adv_model(in_channels + out_channels
+                if args.cgan else out_channels, 1)
+        adv_model.to(args.device)
+        adv_model = DistributedDataParallel(adv_model, device_ids=[args.device])
+
+        adv_criterion = getattr(torch.nn, args.adv_criterion)
+        adv_criterion = adv_criterion()
+        adv_criterion.to(args.device)
+
+        if args.adv_lr is None:
+            args.adv_lr = args.lr
+        if args.adv_weight_decay is None:
+            args.adv_weight_decay = args.weight_decay
+
+        adv_optimizer = getattr(torch.optim, args.optimizer)
+        adv_optimizer = adv_optimizer(
+            adv_model.parameters(),
+            lr=args.adv_lr,
+            betas=(0.5, 0.999),
+            weight_decay=args.adv_weight_decay,
+        )
+        adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
+                factor=0.5, patience=3, verbose=True)
+
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
         args.start_epoch = state['epoch']
         model.module.load_state_dict(state['model'])
         optimizer.load_state_dict(state['optimizer'])
         scheduler.load_state_dict(state['scheduler'])
+        if 'adv_model' in state and args.adv:
+            adv_model.module.load_state_dict(state['adv_model'])
+            adv_optimizer.load_state_dict(state['adv_optimizer'])
+            adv_scheduler.load_state_dict(state['adv_scheduler'])
         torch.set_rng_state(state['rng'].cpu())  # move rng state back
         if args.rank == 0:
             min_loss = state['min_loss']
@@ -111,6 +148,15 @@ def gpu_worker(local_rank, args):
                 state['epoch'], args.load_state))
         del state
     else:
+        # def init_weights(m):
+        #     classname = m.__class__.__name__
+        #     if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
+        #         m.weight.data.normal_(0.0, 0.02)
+        #     elif isinstance(m, torch.nn.BatchNorm3d):
+        #         m.weight.data.normal_(1.0, 0.02)
+        #         m.bias.data.fill_(0)
+        # model.apply(init_weights)
+        #
         args.start_epoch = 0
         if args.rank == 0:
             min_loss = None
@@ -119,47 +165,68 @@ def gpu_worker(local_rank, args):
     if args.rank == 0:
         args.logger = SummaryWriter()
+        #hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
+        #        else str(v) for k, v in vars(args).items()}
+        #args.logger.add_hparams(hparam_dict=hparam, metric_dict={})

     for epoch in range(args.start_epoch, args.epochs):
         if not args.div_data:
             train_sampler.set_epoch(epoch)
-        train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
-        val_loss = validate(epoch, val_loader, model, criterion, args)

-        scheduler.step(val_loss)
+        train_loss = train(epoch, train_loader,
+            model, criterion, optimizer, scheduler,
+            adv_model, adv_criterion, adv_optimizer, adv_scheduler,
+            args)
+        epoch_loss = train_loss
+
+        if args.val:
+            val_loss = validate(epoch, val_loader,
+                model, criterion, adv_model, adv_criterion,
+                args)
+            epoch_loss = val_loss
+
+        scheduler.step(epoch_loss[0])

         if args.rank == 0:
             print(end='', flush=True)
             args.logger.close()

+            is_best = min_loss is None or epoch_loss[0] < min_loss[0]
+            if is_best:
+                min_loss = epoch_loss
+
             state = {
                 'epoch': epoch + 1,
                 'model': model.module.state_dict(),
-                'optimizer' : optimizer.state_dict(),
-                'scheduler' : scheduler.state_dict(),
-                'rng' : torch.get_rng_state(),
+                'optimizer': optimizer.state_dict(),
+                'scheduler': scheduler.state_dict(),
+                'rng': torch.get_rng_state(),
                 'min_loss': min_loss,
             }
+            if args.adv:
+                state.update({
+                    'adv_model': adv_model.module.state_dict(),
+                    'adv_optimizer': adv_optimizer.state_dict(),
+                    'adv_scheduler': adv_scheduler.state_dict(),
+                })
             ckpt_file = 'checkpoint.pth'
             best_file = 'best_model_{}.pth'
             torch.save(state, ckpt_file)
             del state

-            if min_loss is None or val_loss < min_loss:
-                min_loss = val_loss
+            if is_best:
                 shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
                 #if os.path.isfile(best_file.format(epoch)):
                 #    os.remove(best_file.format(epoch))

-    destroy_process_group()
+    dist.destroy_process_group()


-def train(epoch, loader, model, criterion, optimizer, scheduler, args):
+def train(epoch, loader, model, criterion, optimizer, scheduler,
+        adv_model, adv_criterion, adv_optimizer, adv_scheduler, args):
     model.train()
+    if args.adv:
+        adv_model.train()
+
+    # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
+    epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)

     for i, (input, target) in enumerate(loader):
         input = input.to(args.device, non_blocking=True)
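
For reference, a sketch of reading back the checkpoint written by the rank-0 process above; the key names follow the state dict in this diff, while the file location and the surrounding model object are assumptions.

    import torch

    state = torch.load('checkpoint.pth', map_location='cpu')
    print(sorted(state))
    # ['epoch', 'min_loss', 'model', 'optimizer', 'rng', 'scheduler'] plus
    # 'adv_model', 'adv_optimizer', 'adv_scheduler' when an adversary was used
    # model.load_state_dict(state['model'])  # `model` is whatever gpu_worker built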
@@ -169,40 +236,158 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, args):
         target = narrow_like(target, output)  # FIXME pad

         loss = criterion(output, target)
+        epoch_loss[0] += loss.item()
+
+        if args.adv:
+            if args.cgan:
+                if hasattr(model, 'scale_factor') and model.scale_factor != 1:
+                    input = F.interpolate(input,
+                            scale_factor=model.scale_factor, mode='trilinear')
+                input = narrow_like(input, output)
+                output = torch.cat([input, output], dim=1)
+                target = torch.cat([input, target], dim=1)
+
+            # discriminator
+            #
+            # outtgt = torch.cat([output.detach(), target], dim=0)
+            #
+            # eval_outtgt = adv_model(outtgt)
+            #
+            # fake = torch.zeros(1, dtype=torch.float32, device=args.device)
+            # fake = fake.expand_as(output.shape[0] + eval_outtgt.shape[1:])
+            # real = torch.ones(1, dtype=torch.float32, device=args.device)
+            # real = real.expand_as(target.shape[0] + eval_outtgt.shape[1:])
+            # fakereal = torch.cat([fake, real], dim=0)
+
+            eval_out = adv_model(output.detach())
+            fake = torch.zeros(1, dtype=torch.float32,
+                    device=args.device).expand_as(eval_out)
+            adv_loss_fake = adv_criterion(eval_out, fake)  # FIXME try min
+            epoch_loss[3] += adv_loss_fake.item()
+
+            eval_tgt = adv_model(target)
+            real = torch.ones(1, dtype=torch.float32,
+                    device=args.device).expand_as(eval_tgt)
+            adv_loss_real = adv_criterion(eval_tgt, real)  # FIXME try min
+            epoch_loss[4] += adv_loss_real.item()
+
+            adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
+            epoch_loss[2] += adv_loss.item()
+
+            adv_optimizer.zero_grad()
+            adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
+            adv_loss.backward()
+            adv_optimizer.step()
+
+            # generator adversarial loss
+            eval_out = adv_model(output)
+            loss_adv = adv_criterion(eval_out, real)  # FIXME try min
+            epoch_loss[1] += loss_adv.item()
+
+            # loss_fac = loss.item() / (loss.item() + loss_adv.item())
+            # loss = 0.5 * (loss * (1 + loss_fac) + loss_adv * loss_fac)  # FIXME does this work?
+            loss += 0.001 * (loss_adv - loss_adv.item())

         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
+        #if scheduler is not None:  # for batch scheduler
+        #    scheduler.step()

         batch = epoch * len(loader) + i + 1
         if batch % args.log_interval == 0:
-            all_reduce(loss)
+            dist.all_reduce(loss)
             loss /= args.world_size
             if args.rank == 0:
-                args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
+                args.logger.add_scalar('loss/batch/train', loss.item(),
+                        global_step=batch)
+                if args.adv:
+                    args.logger.add_scalar('loss/batch/train/adv/G',
+                            loss_adv.item(), global_step=batch)
+                    args.logger.add_scalars('loss/batch/train/adv/D', {
+                        'total': adv_loss.item(),
+                        'fake': adv_loss_fake.item(),
+                        'real': adv_loss_real.item(),
+                    }, global_step=batch)
+
+    dist.all_reduce(epoch_loss)
+    epoch_loss /= len(loader) * args.world_size
+    if args.rank == 0:
+        args.logger.add_scalar('loss/epoch/train', epoch_loss[0],
+                global_step=epoch+1)
+        if args.adv:
+            args.logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
+                    global_step=epoch+1)
+            args.logger.add_scalars('loss/epoch/train/adv/D', {
+                'total': epoch_loss[2],
+                'fake': epoch_loss[3],
+                'real': epoch_loss[4],
+            }, global_step=epoch+1)
+
+    return epoch_loss
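
The two weighting lines above, adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item() and loss += 0.001 * (loss_adv - loss_adv.item()), keep the logged loss values unchanged while scaling the adversarial gradients by 0.001, because .item() contributes a constant detached from the graph. A minimal sketch of the effect:

    import torch

    x = torch.tensor(2.0, requires_grad=True)
    loss = x ** 2                                   # value 4.0, d(loss)/dx = 4.0
    scaled = 0.001 * loss + 0.999 * loss.item()     # same value, down-weighted gradient
    print(scaled.item())                            # 4.0
    scaled.backward()
    print(x.grad)                                   # tensor(0.0040): gradient scaled by 0.001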
-def validate(epoch, loader, model, criterion, args):
+def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
     model.eval()
+    if args.adv:
+        adv_model.eval()

-    loss = 0
+    # loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
+    epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)

     with torch.no_grad():
-        for i, (input, target) in enumerate(loader):
+        for input, target in loader:
             input = input.to(args.device, non_blocking=True)
             target = target.to(args.device, non_blocking=True)

             output = model(input)
             target = narrow_like(target, output)  # FIXME pad

-            loss += criterion(output, target)
+            loss = criterion(output, target)
+            epoch_loss[0] += loss.item()

-    all_reduce(loss)
-    loss /= len(loader) * args.world_size
+            if args.adv:
+                if args.cgan:
+                    if hasattr(model, 'scale_factor') and model.scale_factor != 1:
+                        input = F.interpolate(input,
+                                scale_factor=model.scale_factor, mode='trilinear')
+                    input = narrow_like(input, output)
+                    output = torch.cat([input, output], dim=1)
+                    target = torch.cat([input, target], dim=1)
+
+                # discriminator
+                eval_out = adv_model(output)
+                fake = torch.zeros(1, dtype=torch.float32,
+                        device=args.device).expand_as(eval_out)  # FIXME criterion wrapper: both D&G; min reduction; expand_as
+                adv_loss_fake = adv_criterion(eval_out, fake)  # FIXME try min
+                epoch_loss[3] += adv_loss_fake.item()
+
+                eval_tgt = adv_model(target)
+                real = torch.ones(1, dtype=torch.float32,
+                        device=args.device).expand_as(eval_tgt)
+                adv_loss_real = adv_criterion(eval_tgt, real)  # FIXME try min
+                epoch_loss[4] += adv_loss_real.item()
+
+                adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
+                epoch_loss[2] += adv_loss.item()
+
+                # generator adversarial loss
+                loss_adv = adv_criterion(eval_out, real)  # FIXME try min
+                epoch_loss[1] += loss_adv.item()
+
+    dist.all_reduce(epoch_loss)
+    epoch_loss /= len(loader) * args.world_size
     if args.rank == 0:
-        args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)
+        args.logger.add_scalar('loss/epoch/val', epoch_loss[0],
+                global_step=epoch+1)
+        if args.adv:
+            args.logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
+                    global_step=epoch+1)
+            args.logger.add_scalars('loss/epoch/val/adv/D', {
+                'total': epoch_loss[2],
+                'fake': epoch_loss[3],
+                'real': epoch_loss[4],
+            }, global_step=epoch+1)

-    return loss.item()
+    return epoch_loss
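
Under --cgan the adversary scores (condition, sample) pairs rather than bare samples. The sketch below only illustrates the tensor plumbing with toy shapes; the channel counts, scale factor, and sizes are made up and the discriminator itself is omitted, so this is not the repository's model code.

    import torch
    import torch.nn.functional as F

    input = torch.randn(2, 3, 16, 16, 16)    # low-resolution condition
    output = torch.randn(2, 3, 32, 32, 32)   # generator prediction
    target = torch.randn(2, 3, 32, 32, 32)   # ground truth

    # upsample the condition to the prediction's size, as in the cgan branch above
    input = F.interpolate(input, scale_factor=2, mode='trilinear')
    fake_pair = torch.cat([input, output], dim=1)   # 6 channels: condition + prediction
    real_pair = torch.cat([input, target], dim=1)   # 6 channels: condition + truth
    # the discriminator is then trained to separate real_pair from fake_pair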