Add optional adversary model and make validation optional
This commit is contained in:
parent
9cf97b3ac1
commit
15384dc9bd
4 changed files with 262 additions and 60 deletions
|
@ -17,17 +17,19 @@ def get_args():
|
|||
|
||||
def add_common_args(parser):
|
||||
parser.add_argument('--norms', type=str_list, help='comma-sep. list '
|
||||
'of normalization functions from data.norms')
|
||||
'of normalization functions from .data.norms')
|
||||
parser.add_argument('--crop', type=int,
|
||||
help='size to crop the input and target data')
|
||||
parser.add_argument('--pad', default=0, type=int,
|
||||
help='pad the input data assuming periodic boundary condition')
|
||||
help='size to pad the input data beyond the crop size, assuming '
|
||||
'periodic boundary condition')
|
||||
|
||||
parser.add_argument('--model', required=True, help='model from models')
|
||||
parser.add_argument('--criterion', default='MSELoss',
|
||||
parser.add_argument('--model', required=True, type=str,
|
||||
help='model from .models')
|
||||
parser.add_argument('--criterion', default='MSELoss', type=str,
|
||||
help='model criterion from torch.nn')
|
||||
parser.add_argument('--load-state', default='', type=str,
|
||||
help='path to load model, optimizer, rng, etc.')
|
||||
help='path to load the states of model, optimizer, rng, etc.')
|
||||
|
||||
parser.add_argument('--batches', default=1, type=int,
|
||||
help='mini-batch size, per GPU in training or in total in testing')
|
||||
|
@ -46,16 +48,23 @@ def add_train_args(parser):
|
|||
help='comma-sep. list of glob patterns for training input data')
|
||||
parser.add_argument('--train-tgt-patterns', type=str_list, required=True,
|
||||
help='comma-sep. list of glob patterns for training target data')
|
||||
parser.add_argument('--val-in-patterns', type=str_list, required=True,
|
||||
parser.add_argument('--val-in-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation input data')
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list, required=True,
|
||||
parser.add_argument('--val-tgt-patterns', type=str_list,
|
||||
help='comma-sep. list of glob patterns for validation target data')
|
||||
parser.add_argument('--augment', action='store_true',
|
||||
help='enable training data augmentation')
|
||||
|
||||
parser.add_argument('--adv-model', type=str,
|
||||
help='enable adversary model from .models')
|
||||
parser.add_argument('--adv-criterion', default='BCEWithLogitsLoss', type=str,
|
||||
help='adversary criterion from torch.nn')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
|
||||
parser.add_argument('--epochs', default=128, type=int,
|
||||
help='total number of epochs to run')
|
||||
parser.add_argument('--optimizer', default='Adam',
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
parser.add_argument('--lr', default=0.001, type=float,
|
||||
help='initial learning rate')
|
||||
|
@ -63,6 +72,10 @@ def add_train_args(parser):
|
|||
# help='momentum')
|
||||
parser.add_argument('--weight-decay', default=0., type=float,
|
||||
help='weight decay')
|
||||
parser.add_argument('--adv-lr', type=float,
|
||||
help='initial adversary learning rate')
|
||||
parser.add_argument('--adv-weight-decay', type=float,
|
||||
help='adversary weight decay')
|
||||
parser.add_argument('--seed', default=42, type=int,
|
||||
help='seed for initializing training')
|
||||
|
||||
|
@ -70,7 +83,7 @@ def add_train_args(parser):
|
|||
help='enable data division among GPUs, useful with cache')
|
||||
parser.add_argument('--dist-backend', default='nccl', type=str,
|
||||
choices=['gloo', 'nccl'], help='distributed backend')
|
||||
parser.add_argument('--log-interval', default=20, type=int,
|
||||
parser.add_argument('--log-interval', default=100, type=int,
|
||||
help='interval between logging training loss')
|
||||
|
||||
|
||||
|
|
|
@ -37,6 +37,8 @@ class FieldDataset(Dataset):
|
|||
assert len(self.in_files) == len(self.tgt_files), \
|
||||
'input and target sample sizes do not match'
|
||||
|
||||
assert len(self.in_files) > 0, 'file not found'
|
||||
|
||||
if div_data:
|
||||
files = len(self.in_files) // world_size
|
||||
self.in_files = self.in_files[rank * files : (rank + 1) * files]
|
||||
|
@ -151,7 +153,7 @@ def flip(fields, axes, ndim):
|
|||
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
if x.size(0) == ndim: # flip vector components
|
||||
if x.shape[0] == ndim: # flip vector components
|
||||
x[axes] = - x[axes]
|
||||
|
||||
axes = (1 + axes).tolist()
|
||||
|
@ -167,7 +169,7 @@ def perm(fields, axes, ndim):
|
|||
|
||||
new_fields = []
|
||||
for x in fields:
|
||||
if x.size(0) == ndim: # permutate vector components
|
||||
if x.shape[0] == ndim: # permutate vector components
|
||||
x = x[axes]
|
||||
|
||||
axes = [0] + (1 + axes).tolist()
|
||||
|
|
|
@ -39,7 +39,9 @@ class ConvBlock(nn.Module):
|
|||
in_channels, out_channels = self._setup_conv()
|
||||
return nn.Conv3d(in_channels, out_channels, self.kernel_size)
|
||||
elif l == 'B':
|
||||
return nn.BatchNorm3d(self.bn_channels)
|
||||
#return nn.BatchNorm3d(self.bn_channels)
|
||||
#return nn.InstanceNorm3d(self.bn_channels, affine=True, track_running_stats=True)
|
||||
return nn.InstanceNorm3d(self.bn_channels)
|
||||
elif l == 'A':
|
||||
return Swish()
|
||||
else:
|
||||
|
@ -109,8 +111,8 @@ def narrow_like(a, b):
|
|||
Try to be symmetric but cut more on the right for odd difference,
|
||||
consistent with the downsampling.
|
||||
"""
|
||||
for dim in range(2, 5):
|
||||
width = a.size(dim) - b.size(dim)
|
||||
for d in range(2, a.dim()):
|
||||
width = a.shape[d] - b.shape[d]
|
||||
half_width = width // 2
|
||||
a = a.narrow(dim, half_width, a.size(dim) - width)
|
||||
a = a.narrow(d, half_width, a.shape[d] - width)
|
||||
return a
|
||||
|
|
275
map2map/train.py
275
map2map/train.py
|
@ -1,8 +1,9 @@
|
|||
import os
|
||||
import shutil
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from torch.multiprocessing import spawn
|
||||
from torch.distributed import init_process_group, destroy_process_group, all_reduce
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -35,7 +36,7 @@ def gpu_worker(local_rank, args):
|
|||
|
||||
args.rank = args.gpus_per_node * args.node + local_rank
|
||||
|
||||
init_process_group(
|
||||
dist.init_process_group(
|
||||
backend=args.dist_backend,
|
||||
init_method='env://',
|
||||
world_size=args.world_size,
|
||||
|
@ -59,23 +60,26 @@ def gpu_worker(local_rank, args):
|
|||
pin_memory=True
|
||||
)
|
||||
|
||||
val_dataset = FieldDataset(
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
augment=False,
|
||||
**{k:v for k, v in vars(args).items() if k != 'augment'},
|
||||
)
|
||||
if not args.div_data:
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistributedSampler(val_dataset)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
shuffle=False,
|
||||
sampler=None if args.div_data else val_sampler,
|
||||
num_workers=args.loader_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
args.val = args.val_in_patterns is not None and \
|
||||
args.val_tgt_patterns is not None
|
||||
if args.val:
|
||||
val_dataset = FieldDataset(
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
augment=False,
|
||||
**{k: v for k, v in vars(args).items() if k != 'augment'},
|
||||
)
|
||||
if not args.div_data:
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistributedSampler(val_dataset)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
shuffle=False,
|
||||
sampler=None if args.div_data else val_sampler,
|
||||
num_workers=args.loader_workers,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
in_channels, out_channels = train_dataset.channels
|
||||
|
||||
|
@ -93,17 +97,50 @@ def gpu_worker(local_rank, args):
|
|||
model.parameters(),
|
||||
lr=args.lr,
|
||||
#momentum=args.momentum,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
||||
factor=0.5, patience=3, verbose=True)
|
||||
|
||||
adv_model = adv_criterion = adv_optimizer = adv_scheduler = None
|
||||
args.adv = args.adv_model is not None
|
||||
if args.adv:
|
||||
adv_model = getattr(models, args.adv_model)
|
||||
adv_model = adv_model(in_channels + out_channels
|
||||
if args.cgan else out_channels, 1)
|
||||
adv_model.to(args.device)
|
||||
adv_model = DistributedDataParallel(adv_model, device_ids=[args.device])
|
||||
|
||||
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
||||
adv_criterion = adv_criterion()
|
||||
adv_criterion.to(args.device)
|
||||
|
||||
if args.adv_lr is None:
|
||||
args.adv_lr = args.lr
|
||||
if args.adv_weight_decay is None:
|
||||
args.adv_weight_decay = args.weight_decay
|
||||
|
||||
adv_optimizer = getattr(torch.optim, args.optimizer)
|
||||
adv_optimizer = adv_optimizer(
|
||||
adv_model.parameters(),
|
||||
lr=args.adv_lr,
|
||||
betas=(0.5, 0.999),
|
||||
weight_decay=args.adv_weight_decay,
|
||||
)
|
||||
adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
|
||||
factor=0.5, patience=3, verbose=True)
|
||||
|
||||
if args.load_state:
|
||||
state = torch.load(args.load_state, map_location=args.device)
|
||||
args.start_epoch = state['epoch']
|
||||
model.module.load_state_dict(state['model'])
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
scheduler.load_state_dict(state['scheduler'])
|
||||
if 'adv_model' in state and args.adv:
|
||||
adv_model.module.load_state_dict(state['adv_model'])
|
||||
adv_optimizer.load_state_dict(state['adv_optimizer'])
|
||||
adv_scheduler.load_state_dict(state['adv_scheduler'])
|
||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||
if args.rank == 0:
|
||||
min_loss = state['min_loss']
|
||||
|
@ -111,6 +148,15 @@ def gpu_worker(local_rank, args):
|
|||
state['epoch'], args.load_state))
|
||||
del state
|
||||
else:
|
||||
# def init_weights(m):
|
||||
# classname = m.__class__.__name__
|
||||
# if isinstance(m, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
|
||||
# m.weight.data.normal_(0.0, 0.02)
|
||||
# elif isinstance(m, torch.nn.BatchNorm3d):
|
||||
# m.weight.data.normal_(1.0, 0.02)
|
||||
# m.bias.data.fill_(0)
|
||||
# model.apply(init_weights)
|
||||
#
|
||||
args.start_epoch = 0
|
||||
if args.rank == 0:
|
||||
min_loss = None
|
||||
|
@ -119,47 +165,68 @@ def gpu_worker(local_rank, args):
|
|||
|
||||
if args.rank == 0:
|
||||
args.logger = SummaryWriter()
|
||||
#hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
|
||||
# else str(v) for k, v in vars(args).items()}
|
||||
#args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if not args.div_data:
|
||||
train_sampler.set_epoch(epoch)
|
||||
train(epoch, train_loader, model, criterion, optimizer, scheduler, args)
|
||||
|
||||
val_loss = validate(epoch, val_loader, model, criterion, args)
|
||||
train_loss = train(epoch, train_loader,
|
||||
model, criterion, optimizer, scheduler,
|
||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler,
|
||||
args)
|
||||
epoch_loss = train_loss
|
||||
|
||||
scheduler.step(val_loss)
|
||||
if args.val:
|
||||
val_loss = validate(epoch, val_loader,
|
||||
model, criterion, adv_model, adv_criterion,
|
||||
args)
|
||||
epoch_loss = val_loss
|
||||
|
||||
scheduler.step(epoch_loss[0])
|
||||
|
||||
if args.rank == 0:
|
||||
print(end='', flush=True)
|
||||
args.logger.close()
|
||||
|
||||
is_best = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if is_best:
|
||||
min_loss = epoch_loss
|
||||
|
||||
state = {
|
||||
'epoch': epoch + 1,
|
||||
'model': model.module.state_dict(),
|
||||
'optimizer' : optimizer.state_dict(),
|
||||
'scheduler' : scheduler.state_dict(),
|
||||
'rng' : torch.get_rng_state(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'scheduler': scheduler.state_dict(),
|
||||
'rng': torch.get_rng_state(),
|
||||
'min_loss': min_loss,
|
||||
}
|
||||
if args.adv:
|
||||
state.update({
|
||||
'adv_model': adv_model.module.state_dict(),
|
||||
'adv_optimizer': adv_optimizer.state_dict(),
|
||||
'adv_scheduler': adv_scheduler.state_dict(),
|
||||
})
|
||||
ckpt_file = 'checkpoint.pth'
|
||||
best_file = 'best_model_{}.pth'
|
||||
torch.save(state, ckpt_file)
|
||||
del state
|
||||
|
||||
if min_loss is None or val_loss < min_loss:
|
||||
min_loss = val_loss
|
||||
if is_best:
|
||||
shutil.copyfile(ckpt_file, best_file.format(epoch + 1))
|
||||
#if os.path.isfile(best_file.format(epoch)):
|
||||
# os.remove(best_file.format(epoch))
|
||||
|
||||
destroy_process_group()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def train(epoch, loader, model, criterion, optimizer, scheduler, args):
|
||||
def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
adv_model, adv_criterion, adv_optimizer, adv_scheduler, args):
|
||||
model.train()
|
||||
if args.adv:
|
||||
adv_model.train()
|
||||
|
||||
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
||||
|
||||
for i, (input, target) in enumerate(loader):
|
||||
input = input.to(args.device, non_blocking=True)
|
||||
|
@ -169,40 +236,158 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, args):
|
|||
target = narrow_like(target, output) # FIXME pad
|
||||
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='trilinear')
|
||||
input = narrow_like(input, output)
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
#
|
||||
# outtgt = torch.cat([output.detach(), target], dim=0)
|
||||
#
|
||||
# eval_outtgt = adv_model(outtgt)
|
||||
#
|
||||
# fake = torch.zeros(1, dtype=torch.float32, device=args.device)
|
||||
# fake = fake.expand_as(output.shape[0] + eval_outtgt.shape[1:])
|
||||
# real = torch.ones(1, dtype=torch.float32, device=args.device)
|
||||
# real = real.expand_as(target.shape[0] + eval_outtgt.shape[1:])
|
||||
# fakereal = torch.cat([fake, real], dim=0)
|
||||
|
||||
eval_out = adv_model(output.detach())
|
||||
fake = torch.zeros(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out)
|
||||
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
|
||||
eval_tgt = adv_model(target)
|
||||
real = torch.ones(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_tgt)
|
||||
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
adv_optimizer.zero_grad()
|
||||
adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
|
||||
adv_loss.backward()
|
||||
adv_optimizer.step()
|
||||
|
||||
# generator adversarial loss
|
||||
|
||||
eval_out = adv_model(output)
|
||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
# loss_fac = loss.item() / (loss.item() + loss_adv.item())
|
||||
# loss = 0.5 * (loss * (1 + loss_fac) + loss_adv * loss_fac) # FIXME does this work?
|
||||
loss += 0.001 * (loss_adv - loss_adv.item())
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
#if scheduler is not None: # for batch scheduler
|
||||
#scheduler.step()
|
||||
|
||||
batch = epoch * len(loader) + i + 1
|
||||
if batch % args.log_interval == 0:
|
||||
all_reduce(loss)
|
||||
dist.all_reduce(loss)
|
||||
loss /= args.world_size
|
||||
if args.rank == 0:
|
||||
args.logger.add_scalar('loss/train', loss.item(), global_step=batch)
|
||||
args.logger.add_scalar('loss/batch/train', loss.item(),
|
||||
global_step=batch)
|
||||
if args.adv:
|
||||
args.logger.add_scalar('loss/batch/train/adv/G',
|
||||
loss_adv.item(), global_step=batch)
|
||||
args.logger.add_scalars('loss/batch/train/adv/D', {
|
||||
'total': adv_loss.item(),
|
||||
'fake': adv_loss_fake.item(),
|
||||
'real': adv_loss_real.item(),
|
||||
}, global_step=batch)
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * args.world_size
|
||||
if args.rank == 0:
|
||||
args.logger.add_scalar('loss/epoch/train', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
if args.adv:
|
||||
args.logger.add_scalar('loss/epoch/train/adv/G', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
args.logger.add_scalars('loss/epoch/train/adv/D', {
|
||||
'total': epoch_loss[2],
|
||||
'fake': epoch_loss[3],
|
||||
'real': epoch_loss[4],
|
||||
}, global_step=epoch+1)
|
||||
|
||||
return epoch_loss
|
||||
|
||||
|
||||
def validate(epoch, loader, model, criterion, args):
|
||||
def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
||||
model.eval()
|
||||
if args.adv:
|
||||
adv_model.eval()
|
||||
|
||||
loss = 0
|
||||
# loss, loss_adv, adv_loss, adv_loss_fake, adv_loss_real
|
||||
epoch_loss = torch.zeros(5, dtype=torch.float64, device=args.device)
|
||||
|
||||
with torch.no_grad():
|
||||
for i, (input, target) in enumerate(loader):
|
||||
for input, target in loader:
|
||||
input = input.to(args.device, non_blocking=True)
|
||||
target = target.to(args.device, non_blocking=True)
|
||||
|
||||
output = model(input)
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
|
||||
loss += criterion(output, target)
|
||||
loss = criterion(output, target)
|
||||
epoch_loss[0] += loss.item()
|
||||
|
||||
all_reduce(loss)
|
||||
loss /= len(loader) * args.world_size
|
||||
if args.adv:
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='trilinear')
|
||||
input = narrow_like(input, output)
|
||||
output = torch.cat([input, output], dim=1)
|
||||
target = torch.cat([input, target], dim=1)
|
||||
|
||||
# discriminator
|
||||
|
||||
eval_out = adv_model(output)
|
||||
fake = torch.zeros(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_out) # FIXME criterion wrapper: both D&G; min reduction; expand_as
|
||||
adv_loss_fake = adv_criterion(eval_out, fake) # FIXME try min
|
||||
epoch_loss[3] += adv_loss_fake.item()
|
||||
|
||||
eval_tgt = adv_model(target)
|
||||
real = torch.ones(1, dtype=torch.float32,
|
||||
device=args.device).expand_as(eval_tgt)
|
||||
adv_loss_real = adv_criterion(eval_tgt, real) # FIXME try min
|
||||
epoch_loss[4] += adv_loss_real.item()
|
||||
|
||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||
epoch_loss[2] += adv_loss.item()
|
||||
|
||||
# generator adversarial loss
|
||||
|
||||
loss_adv = adv_criterion(eval_out, real) # FIXME try min
|
||||
epoch_loss[1] += loss_adv.item()
|
||||
|
||||
dist.all_reduce(epoch_loss)
|
||||
epoch_loss /= len(loader) * args.world_size
|
||||
if args.rank == 0:
|
||||
args.logger.add_scalar('loss/val', loss.item(), global_step=epoch+1)
|
||||
args.logger.add_scalar('loss/epoch/val', epoch_loss[0],
|
||||
global_step=epoch+1)
|
||||
if args.adv:
|
||||
args.logger.add_scalar('loss/epoch/val/adv/G', epoch_loss[1],
|
||||
global_step=epoch+1)
|
||||
args.logger.add_scalars('loss/epoch/val/adv/D', {
|
||||
'total': epoch_loss[2],
|
||||
'fake': epoch_loss[3],
|
||||
'real': epoch_loss[4],
|
||||
}, global_step=epoch+1)
|
||||
|
||||
return loss.item()
|
||||
return epoch_loss
|
||||
|
|
Loading…
Reference in a new issue