Remove noise channels

Yin Li 2020-04-15 13:26:48 -04:00
parent 3e13202fb5
commit f442dd59ba
4 changed files with 2 additions and 19 deletions

View File

@@ -37,9 +37,6 @@ def add_common_args(parser):
     parser.add_argument('--scale-factor', default=1, type=int,
             help='input upsampling factor for super-resolution purpose, in '
             'which case crop and pad will be taken at the original resolution')
-    parser.add_argument('--noise-chan', default=0, type=int,
-            help='input noise channels to produce the output stochasticity, '
-            'if the input does not completely determines the output')
     parser.add_argument('--model', required=True, type=str,
             help='model from .models')
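
For context, the remaining hunks in this commit remove the option everywhere it was consumed. A rough sketch of what `--noise-chan` used to do to the model's input width, with made-up channel counts and a hypothetical `Model` standing in for `getattr(models, args.model)`:

import torch.nn as nn

class Model(nn.Module):                       # hypothetical stand-in for a
    def __init__(self, in_chan, out_chan):    # network from .models
        super().__init__()
        self.conv = nn.Conv3d(in_chan, out_chan, 3)

    def forward(self, x):
        return self.conv(x)

in_chan, out_chan, noise_chan = [3], [3], 2   # assumed per-field channel counts

model_before = Model(sum(in_chan) + noise_chan, sum(out_chan))  # noise widened the input side
model_after = Model(sum(in_chan), sum(out_chan))                # after this commit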

View File

@@ -27,8 +27,6 @@ class FieldDataset(Dataset):
     the input for super-resolution, in which case `crop` and `pad` are sizes of
     the input resolution.
 
-    Noise channels can be concatenated to the input.
-
     `cache` enables data caching.
     `div_data` enables data division, useful when combined with caching.
     """
@@ -142,11 +140,6 @@
         for norm, x in zip(self.tgt_norms, tgt_fields):
             norm(x)
 
-        if self.noise_chan > 0:
-            in_fields.append(
-                torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
-                    dtype=torch.float32))
-
         in_fields = torch.cat(in_fields, dim=0)
         tgt_fields = torch.cat(tgt_fields, dim=0)
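
A minimal sketch of the behavior deleted here: fresh noise was drawn on every `__getitem__` call, so the same input field produced a different augmented sample each time. The `with_noise` helper and the shapes are hypothetical:

import torch

def with_noise(field, noise_chan):
    # mirrors the deleted FieldDataset logic: standard-normal channels
    # matching the field's spatial shape, concatenated along channels
    noise = torch.randn((noise_chan,) + field.shape[1:], dtype=torch.float32)
    return torch.cat([field, noise], dim=0)

field = torch.ones(1, 8, 8, 8)
a, b = with_noise(field, 1), with_noise(field, 1)
assert torch.equal(a[:1], b[:1])        # data channels identical across draws
assert not torch.equal(a[1:], b[1:])    # noise channels differ per draw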

View File

@@ -22,7 +22,6 @@ def test(args):
         crop=args.crop,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        noise_chan=args.noise_chan,
         cache=args.cache,
     )
     test_loader = DataLoader(
@@ -35,7 +34,7 @@ def test(args):
     in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
+    model = model(sum(in_chan), sum(out_chan))
     criterion = getattr(torch.nn, args.criterion)
     criterion = criterion()

View File

@@ -66,7 +66,6 @@ def gpu_worker(local_rank, node, args):
         crop=args.crop,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        noise_chan=args.noise_chan,
         cache=args.cache,
         div_data=args.div_data,
         rank=rank,
@@ -96,7 +95,6 @@ def gpu_worker(local_rank, node, args):
         crop=args.crop,
         pad=args.pad,
         scale_factor=args.scale_factor,
-        noise_chan=args.noise_chan,
         cache=args.cache,
         div_data=args.div_data,
         rank=rank,
@@ -119,7 +117,7 @@ def gpu_worker(local_rank, node, args):
     args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
 
     model = getattr(models, args.model)
-    model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
+    model = model(sum(args.in_chan), sum(args.out_chan))
     model.to(device)
     model = DistributedDataParallel(model, device_ids=[device],
             process_group=dist.new_group())
@@ -306,8 +304,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
         output = model(input)
         target = narrow_like(target, output)  # FIXME pad
 
-        if args.noise_chan > 0:
-            input = input[:, :-args.noise_chan]  # remove noise channels
         if hasattr(model, 'scale_factor') and model.scale_factor != 1:
             input = F.interpolate(input,
                     scale_factor=model.scale_factor, mode='nearest')
@@ -450,8 +446,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
         output = model(input)
         target = narrow_like(target, output)  # FIXME pad
 
-        if args.noise_chan > 0:
-            input = input[:, :-args.noise_chan]  # remove noise channels
         if hasattr(model, 'scale_factor') and model.scale_factor != 1:
             input = F.interpolate(input,
                     scale_factor=model.scale_factor, mode='nearest')
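
The slice deleted in the two hunks above relied on the noise channels sitting last along the channel dimension, so dropping them recovered the data-only input before it was upsampled for comparison. A sketch with assumed shapes:

import torch

noise_chan = 2
input = torch.randn(4, 3 + noise_chan, 16, 16, 16)  # (batch, data + noise, x, y, z)
input = input[:, :-noise_chan]                      # drop the trailing noise channels
assert input.shape == (4, 3, 16, 16, 16)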