diff --git a/map2map/args.py b/map2map/args.py index 31c1625..768dd35 100644 --- a/map2map/args.py +++ b/map2map/args.py @@ -28,6 +28,9 @@ def add_common_args(parser): parser.add_argument('--scale-factor', default=1, type=int, help='input upsampling factor for super-resolution purpose, in ' 'which case crop and pad will be taken at the original resolution') + parser.add_argument('--noise-chan', default=0, type=int, + help='input noise channels to produce the output stochasticity, ' + 'if the input does not completely determines the output') parser.add_argument('--model', required=True, type=str, help='model from .models') diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 32d13b2..08cde15 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -27,11 +27,13 @@ class FieldDataset(Dataset): in which case `crop`, `pad`, and other spatial attributes will be taken at the original resolution. + Noise channels can be concatenated to the input. + `cache` enables data caching. `div_data` enables data division, useful when combined with caching. """ def __init__(self, in_patterns, tgt_patterns, in_norms=None, tgt_norms=None, - augment=False, crop=None, pad=0, scale_factor=1, + augment=False, crop=None, pad=0, scale_factor=1, noise_chan=0, cache=False, div_data=False, rank=None, world_size=None, **kwargs): in_file_lists = [sorted(glob(p)) for p in in_patterns] @@ -88,6 +90,10 @@ class FieldDataset(Dataset): "only support integer upsampling" self.scale_factor = scale_factor + assert isinstance(noise_chan, int) and noise_chan >= 0, \ + "only support integer noise channels" + self.noise_chan = noise_chan + self.cache = cache if self.cache: self.in_fields = {} @@ -139,6 +145,11 @@ class FieldDataset(Dataset): for norm, x in zip(self.tgt_norms, tgt_fields): norm(x) + if self.noise_chan > 0: + in_fields.append( + torch.randn((self.noise_chan,) + in_fields[0].shape[1:], + dtype=torch.float32)) + in_fields = torch.cat(in_fields, dim=0) tgt_fields = torch.cat(tgt_fields, dim=0) diff --git a/map2map/models/patchgan.py b/map2map/models/patchgan.py index 74a59af..df57afa 100644 --- a/map2map/models/patchgan.py +++ b/map2map/models/patchgan.py @@ -8,10 +8,10 @@ class PatchGAN(nn.Module): super().__init__() self.convs = nn.Sequential( - ConvBlock(in_chan, 64, seq='CA'), + ConvBlock(in_chan, 32, seq='CA'), + ConvBlock(32, 64, seq='CBA'), ConvBlock(64, 128, seq='CBA'), - ConvBlock(128, 256, seq='CBA'), - nn.Conv3d(256, out_chan, 1) + nn.Conv3d(128, out_chan, 1) ) def forward(self, x): diff --git a/map2map/test.py b/map2map/test.py index a7a5dbb..2c086d4 100644 --- a/map2map/test.py +++ b/map2map/test.py @@ -26,7 +26,7 @@ def test(args): in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan model = getattr(models, args.model) - model = model(sum(in_chan), sum(out_chan)) + model = model(sum(in_chan) + args.noise_chan, sum(out_chan)) criterion = getattr(torch.nn, args.criterion) criterion = criterion() diff --git a/map2map/train.py b/map2map/train.py index d476a96..75d5378 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -84,7 +84,7 @@ def gpu_worker(local_rank, args): in_chan, out_chan = train_dataset.in_chan, train_dataset.tgt_chan model = getattr(models, args.model) - model = model(sum(in_chan), sum(out_chan)) + model = model(sum(in_chan) + args.noise_chan, sum(out_chan)) model.to(args.device) model = DistributedDataParallel(model, device_ids=[args.device], process_group=dist.new_group()) @@ -247,6 +247,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler, # generator adversarial loss if args.adv: + if args.noise_chan > 0: + input = input[:, :-args.noise_chan] # remove noise channels if args.cgan: if hasattr(model, 'scale_factor') and model.scale_factor != 1: input = F.interpolate(input, @@ -340,6 +342,8 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion, args): epoch_loss[0] += loss.item() if args.adv: + if args.noise_chan > 0: + input = input[:, :-args.noise_chan] # remove noise channels if args.cgan: if hasattr(model, 'scale_factor') and model.scale_factor != 1: input = F.interpolate(input,