Add noise channels to the input
This commit is contained in:
parent
31ea70fca9
commit
0721301113
5 changed files with 24 additions and 6 deletions
|
@ -28,6 +28,9 @@ def add_common_args(parser):
|
|||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='input upsampling factor for super-resolution purpose, in '
|
||||
'which case crop and pad will be taken at the original resolution')
|
||||
parser.add_argument('--noise-chan', default=0, type=int,
|
||||
help='input noise channels to produce the output stochasticity, '
|
||||
'if the input does not completely determines the output')
|
||||
|
||||
parser.add_argument('--model', required=True, type=str,
|
||||
help='model from .models')
|
||||
|
|
|
@ -27,11 +27,13 @@ class FieldDataset(Dataset):
|
|||
in which case `crop`, `pad`, and other spatial attributes will be taken
|
||||
at the original resolution.
|
||||
|
||||
Noise channels can be concatenated to the input.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
"""
|
||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||
augment=False, crop=None, pad=0, scale_factor=1,
|
||||
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
||||
cache=False, div_data=False, rank=None, world_size=None,
|
||||
**kwargs):
|
||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||
|
@ -88,6 +90,10 @@ class FieldDataset(Dataset):
|
|||
"only support integer upsampling"
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
assert isinstance(noise_chan, int) and noise_chan >= 0, \
|
||||
"only support integer noise channels"
|
||||
self.noise_chan = noise_chan
|
||||
|
||||
self.cache = cache
|
||||
if self.cache:
|
||||
self.in_fields = {}
|
||||
|
@ -139,6 +145,11 @@ class FieldDataset(Dataset):
|
|||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||
norm(x)
|
||||
|
||||
if self.noise_chan > 0:
|
||||
in_fields.append(
|
||||
torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
|
||||
dtype=torch.float32))
|
||||
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
|
|
|
@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
|
|||
super().__init__()
|
||||
|
||||
self.convs = nn.Sequential(
|
||||
ConvBlock(in_chan, 64, seq='CA'),
|
||||
ConvBlock(in_chan, 32, seq='CA'),
|
||||
ConvBlock(32, 64, seq='CBA'),
|
||||
ConvBlock(64, 128, seq='CBA'),
|
||||
ConvBlock(128, 256, seq='CBA'),
|
||||
nn.Conv3d(256, out_chan, 1)
|
||||
nn.Conv3d(128, out_chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -26,7 +26,7 @@ def test(args):
|
|||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||
|
||||
model = getattr(models, args.model)
|
||||
model = model(sum(in_chan), sum(out_chan))
|
||||
model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
|
||||
criterion = getattr(torch.nn, args.criterion)
|
||||
criterion = criterion()
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ def gpu_worker(local_rank, args):
|
|||
in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||
|
||||
model = getattr(models, args.model)
|
||||
model = model(sum(in_chan), sum(out_chan))
|
||||
model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
|
||||
model.to(args.device)
|
||||
model = DistributedDataParallel(model, device_ids=[args.device],
|
||||
process_group=dist.new_group())
|
||||
|
@ -247,6 +247,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||
|
||||
# generator adversarial loss
|
||||
if args.adv:
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
|
@ -340,6 +342,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||
epoch_loss[0] += loss.item()
|
||||
|
||||
if args.adv:
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if args.cgan:
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
|
|
Loading…
Reference in a new issue