Add noise channels to the input
This commit is contained in:
parent
31ea70fca9
commit
0721301113
@ -28,6 +28,9 @@ def add_common_args(parser):
|
|||||||
parser.add_argument('--scale-factor', default=1, type=int,
|
parser.add_argument('--scale-factor', default=1, type=int,
|
||||||
help='input upsampling factor for super-resolution purpose, in '
|
help='input upsampling factor for super-resolution purpose, in '
|
||||||
'which case crop and pad will be taken at the original resolution')
|
'which case crop and pad will be taken at the original resolution')
|
||||||
|
parser.add_argument('--noise-chan', default=0, type=int,
|
||||||
|
help='input noise channels to produce the output stochasticity, '
|
||||||
|
'if the input does not completely determines the output')
|
||||||
|
|
||||||
parser.add_argument('--model', required=True, type=str,
|
parser.add_argument('--model', required=True, type=str,
|
||||||
help='model from .models')
|
help='model from .models')
|
||||||
|
@ -27,11 +27,13 @@ class FieldDataset(Dataset):
|
|||||||
in which case `crop`, `pad`, and other spatial attributes will be taken
|
in which case `crop`, `pad`, and other spatial attributes will be taken
|
||||||
at the original resolution.
|
at the original resolution.
|
||||||
|
|
||||||
|
Noise channels can be concatenated to the input.
|
||||||
|
|
||||||
`cache` enables data caching.
|
`cache` enables data caching.
|
||||||
`div_data` enables data division, useful when combined with caching.
|
`div_data` enables data division, useful when combined with caching.
|
||||||
"""
|
"""
|
||||||
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None,
|
||||||
augment=False, crop=None, pad=0, scale_factor=1,
|
augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0,
|
||||||
cache=False, div_data=False, rank=None, world_size=None,
|
cache=False, div_data=False, rank=None, world_size=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
in_file_lists = [sorted(glob(p)) for p in in_patterns]
|
||||||
@ -88,6 +90,10 @@ class FieldDataset(Dataset):
|
|||||||
"only support integer upsampling"
|
"only support integer upsampling"
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
assert isinstance(noise_chan, int) and noise_chan >= 0, \
|
||||||
|
"only support integer noise channels"
|
||||||
|
self.noise_chan = noise_chan
|
||||||
|
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
if self.cache:
|
if self.cache:
|
||||||
self.in_fields = {}
|
self.in_fields = {}
|
||||||
@ -139,6 +145,11 @@ class FieldDataset(Dataset):
|
|||||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||||
norm(x)
|
norm(x)
|
||||||
|
|
||||||
|
if self.noise_chan > 0:
|
||||||
|
in_fields.append(
|
||||||
|
torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
|
||||||
|
dtype=torch.float32))
|
||||||
|
|
||||||
in_fields = torch.cat(in_fields, dim=0)
|
in_fields = torch.cat(in_fields, dim=0)
|
||||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||||
|
|
||||||
|
@ -8,10 +8,10 @@ class PatchGAN(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.convs = nn.Sequential(
|
self.convs = nn.Sequential(
|
||||||
ConvBlock(in_chan, 64, seq='CA'),
|
ConvBlock(in_chan, 32, seq='CA'),
|
||||||
|
ConvBlock(32, 64, seq='CBA'),
|
||||||
ConvBlock(64, 128, seq='CBA'),
|
ConvBlock(64, 128, seq='CBA'),
|
||||||
ConvBlock(128, 256, seq='CBA'),
|
nn.Conv3d(128, out_chan, 1)
|
||||||
nn.Conv3d(256, out_chan, 1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -26,7 +26,7 @@ def test(args):
|
|||||||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||||
|
|
||||||
model = getattr(models, args.model)
|
model = getattr(models, args.model)
|
||||||
model = model(sum(in_chan), sum(out_chan))
|
model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
|
||||||
criterion = getattr(torch.nn, args.criterion)
|
criterion = getattr(torch.nn, args.criterion)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ def gpu_worker(local_rank, args):
|
|||||||
in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||||
|
|
||||||
model = getattr(models, args.model)
|
model = getattr(models, args.model)
|
||||||
model = model(sum(in_chan), sum(out_chan))
|
model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model = DistributedDataParallel(model, device_ids=[args.device],
|
model = DistributedDataParallel(model, device_ids=[args.device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
@ -247,6 +247,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
|
|
||||||
# generator adversarial loss
|
# generator adversarial loss
|
||||||
if args.adv:
|
if args.adv:
|
||||||
|
if args.noise_chan > 0:
|
||||||
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
@ -340,6 +342,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args):
|
|||||||
epoch_loss[0] += loss.item()
|
epoch_loss[0] += loss.item()
|
||||||
|
|
||||||
if args.adv:
|
if args.adv:
|
||||||
|
if args.noise_chan > 0:
|
||||||
|
input = input[:, :-args.noise_chan] # remove noise channels
|
||||||
if args.cgan:
|
if args.cgan:
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
|
Loading…
Reference in New Issue
Block a user