diff --git a/map2map/args.py b/map2map/args.py
index f685a47..0da57b6 100644
--- a/map2map/args.py
+++ b/map2map/args.py
@@ -37,9 +37,6 @@ def add_common_args(parser):
     parser.add_argument('--scale-factor', default=1, type=int,
             help='input upsampling factor for super-resolution purpose, in '
             'which case crop and pad will be taken at the original resolution')
-    parser.add_argument('--noise-chan', default=0, type=int,
-            help='input noise channels to produce the output stochasticity, '
-            'if the input does not completely determines the output')
 
     parser.add_argument('--model', required=True, type=str,
             help='model from .models')
diff --git a/map2map/data/fields.py b/map2map/data/fields.py
index 092a476..d0d6869 100644
--- a/map2map/data/fields.py
+++ b/map2map/data/fields.py
@@ -27,8 +27,6 @@ class FieldDataset(Dataset):
     the input for super-resolution, in which case `crop` and `pad` are sizes
     of the input resolution.
 
-    Noise channels can be concatenated to the input.
-
     `cache` enables data caching.
     `div_data` enables data division, useful when combined with caching.
     """
@@ -142,11 +140,6 @@ class FieldDataset(Dataset):
             for norm, x in zip(self.tgt_norms, tgt_fields):
                 norm(x)
 
-        if self.noise_chan > 0:
-            in_fields.append(
-                torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
-                            dtype=torch.float32))
-
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)
 
diff --git a/map2map/test.py b/map2map/test.py
index efd9e52..f349eb0 100644
--- a/map2map/test.py
+++ b/map2map/test.py
@@ -22,7 +22,6 @@ def test(args):
         crop=args.crop,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        noise_chan=args.noise_chan,
         cache=args.cache,
     )
     test_loader = DataLoader(
@@ -35,7 +34,7 @@ def test(args):
     in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
+    model = model(sum(in_chan), sum(out_chan))
     criterion = getattr(torch.nn, args.criterion)
     criterion = criterion()
 
diff --git a/map2map/train.py b/map2map/train.py
index 39a3d9a..ba23a72 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -66,7 +66,6 @@ def gpu_worker(local_rank, node, args):
         crop=args.crop,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        noise_chan=args.noise_chan,
         cache=args.cache,
         div_data=args.div_data,
         rank=rank,
@@ -96,7 +95,6 @@ def gpu_worker(local_rank, node, args):
         crop=args.crop,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        noise_chan=args.noise_chan,
         cache=args.cache,
         div_data=args.div_data,
         rank=rank,
@@ -119,7 +117,7 @@ def gpu_worker(local_rank, node, args):
     args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
+    model = model(sum(args.in_chan), sum(args.out_chan))
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())
@@ -306,8 +304,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         output = model(input)
         target = narrow_like(target, output)  # FIXME pad
 
-        if args.noise_chan > 0:
-            input = input[:, :-args.noise_chan]  # remove noise channels
         if hasattr(model, 'scale_factor') and model.scale_factor != 1:
             input = F.interpolate(input, scale_factor=model.scale_factor,
                     mode='nearest')
@@ -450,8 +446,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
         output = model(input)
         target = narrow_like(target, output)  # FIXME pad
 
-        if args.noise_chan > 0:
-            input = input[:, :-args.noise_chan]  # remove noise channels
        if hasattr(model, 'scale_factor') and model.scale_factor != 1:
             input = F.interpolate(input, scale_factor=model.scale_factor,
                     mode='nearest')