Add noise channels to the input

This commit is contained in:
Yin Li 2020-01-27 13:48:32 -05:00
parent 31ea70fca9
commit 0721301113
5 changed files with 24 additions and 6 deletions

View File

@ -28,6 +28,9 @@ def add_common_args(parser):
parser.add_argument('--scale-factor', default=1, type=int, parser.add_argument('--scale-factor', default=1, type=int,
help='input upsampling factor for super-resolution purpose, in ' help='input upsampling factor for super-resolution purpose, in '
'which case crop and pad will be taken at the original resolution') 'which case crop and pad will be taken at the original resolution')
parser.add_argument('--noise-chan', default=0, type=int,
help='input noise channels to produce the output stochasticity, '
'if the input does not completely determines the output')
parser.add_argument('--model', required=True, type=str, parser.add_argument('--model', required=True, type=str,
help='model from .models') help='model from .models')

View File

@ -27,11 +27,13 @@ class FieldDataset(Dataset):
in which case `crop`, `pad`, and other spatial attributes will be taken in which case `crop`, `pad`, and other spatial attributes will be taken
at the original resolution. at the original resolution.
Noise channels can be concatenated to the input.
`cache` enables data caching. `cache` enables data caching.
`div_data` enables data division, useful when combined with caching. `div_data` enables data division, useful when combined with caching.
""" """
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
augment=False, crop=None, pad=0, scale_factor=1, augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
cache=False, div_data=False, rank=None, world_size=None, cache=False, div_data=False, rank=None, world_size=None,
**kwargs): **kwargs):
in_file_lists = [sorted(glob(p)) for p in in_patterns] in_file_lists = [sorted(glob(p)) for p in in_patterns]
@ -88,6 +90,10 @@ class FieldDataset(Dataset):
"only support integer upsampling" "only support integer upsampling"
self.scale_factor = scale_factor self.scale_factor = scale_factor
assert isinstance(noise_chan, int) and noise_chan >= 0, \
"only support integer noise channels"
self.noise_chan = noise_chan
self.cache = cache self.cache = cache
if self.cache: if self.cache:
self.in_fields = {} self.in_fields = {}
@ -139,6 +145,11 @@ class FieldDataset(Dataset):
for norm, x in zip(self.tgt_norms, tgt_fields): for norm, x in zip(self.tgt_norms, tgt_fields):
norm(x) norm(x)
if self.noise_chan > 0:
in_fields.append(
torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
dtype=torch.float32))
in_fields = torch.cat(in_fields, dim=0) in_fields = torch.cat(in_fields, dim=0)
tgt_fields = torch.cat(tgt_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0)

View File

@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
super().__init__() super().__init__()
self.convs = nn.Sequential( self.convs = nn.Sequential(
ConvBlock(in_chan, 64, seq='CA'), ConvBlock(in_chan, 32, seq='CA'),
ConvBlock(32, 64, seq='CBA'),
ConvBlock(64, 128, seq='CBA'), ConvBlock(64, 128, seq='CBA'),
ConvBlock(128, 256, seq='CBA'), nn.Conv3d(128, out_chan, 1)
nn.Conv3d(256, out_chan, 1)
) )
def forward(self, x): def forward(self, x):

View File

@ -26,7 +26,7 @@ def test(args):
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
model = getattr(models, args.model) model = getattr(models, args.model)
model = model(sum(in_chan), sum(out_chan)) model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
criterion = getattr(torch.nn, args.criterion) criterion = getattr(torch.nn, args.criterion)
criterion = criterion() criterion = criterion()

View File

@ -84,7 +84,7 @@ def gpu_worker(local_rank, args):
in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan
model = getattr(models, args.model) model = getattr(models, args.model)
model = model(sum(in_chan), sum(out_chan)) model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
model.to(args.device) model.to(args.device)
model = DistributedDataParallel(model, device_ids=[args.device], model = DistributedDataParallel(model, device_ids=[args.device],
process_group=dist.new_group()) process_group=dist.new_group())
@ -247,6 +247,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
# generator adversarial loss # generator adversarial loss
if args.adv: if args.adv:
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.cgan: if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input, input = F.interpolate(input,
@ -340,6 +342,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
epoch_loss[0] += loss.item() epoch_loss[0] += loss.item()
if args.adv: if args.adv:
if args.noise_chan > 0:
input = input[:, :-args.noise_chan] # remove noise channels
if args.cgan: if args.cgan:
if hasattr(model, 'scale_factor') and model.scale_factor != 1: if hasattr(model, 'scale_factor') and model.scale_factor != 1:
input = F.interpolate(input, input = F.interpolate(input,