Remove noise channels
This commit is contained in:
parent
3e13202fb5
commit
f442dd59ba
@ -37,9 +37,6 @@ def add_common_args(parser):
|
||||
parser.add_argument('--scale-factor', default=1, type=int,
|
||||
help='input upsampling factor for super-resolution purpose, in '
|
||||
'which case crop and pad will be taken at the original resolution')
|
||||
parser.add_argument('--noise-chan', default=0, type=int,
|
||||
help='input noise channels to produce the output stochasticity, '
|
||||
'if the input does not completely determines the output')
|
||||
|
||||
parser.add_argument('--model', required=True, type=str,
|
||||
help='model from .models')
|
||||
|
@ -27,8 +27,6 @@ class FieldDataset(Dataset):
|
||||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||
the input resolution.
|
||||
|
||||
Noise channels can be concatenated to the input.
|
||||
|
||||
`cache` enables data caching.
|
||||
`div_data` enables data division, useful when combined with caching.
|
||||
"""
|
||||
@ -142,11 +140,6 @@ class FieldDataset(Dataset):
|
||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||
norm(x)
|
||||
|
||||
if self.noise_chan > 0:
|
||||
in_fields.append(
|
||||
torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
|
||||
dtype=torch.float32))
|
||||
|
||||
in_fields = torch.cat(in_fields, dim=0)
|
||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||
|
||||
|
@ -22,7 +22,6 @@ def test(args):
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
noise_chan=args.noise_chan,
|
||||
cache=args.cache,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
@ -35,7 +34,7 @@ def test(args):
|
||||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||
|
||||
model = getattr(models, args.model)
|
||||
model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
|
||||
model = model(sum(in_chan), sum(out_chan))
|
||||
criterion = getattr(torch.nn, args.criterion)
|
||||
criterion = criterion()
|
||||
|
||||
|
@ -66,7 +66,6 @@ def gpu_worker(local_rank, node, args):
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
noise_chan=args.noise_chan,
|
||||
cache=args.cache,
|
||||
div_data=args.div_data,
|
||||
rank=rank,
|
||||
@ -96,7 +95,6 @@ def gpu_worker(local_rank, node, args):
|
||||
crop=args.crop,
|
||||
pad=args.pad,
|
||||
scale_factor=args.scale_factor,
|
||||
noise_chan=args.noise_chan,
|
||||
cache=args.cache,
|
||||
div_data=args.div_data,
|
||||
rank=rank,
|
||||
@ -119,7 +117,7 @@ def gpu_worker(local_rank, node, args):
|
||||
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||
|
||||
model = getattr(models, args.model)
|
||||
model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
|
||||
model = model(sum(args.in_chan), sum(args.out_chan))
|
||||
model.to(device)
|
||||
model = DistributedDataParallel(model, device_ids=[device],
|
||||
process_group=dist.new_group())
|
||||
@ -306,8 +304,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
||||
output = model(input)
|
||||
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
@ -450,8 +446,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
||||
output = model(input)
|
||||
|
||||
target = narrow_like(target, output) # FIXME pad
|
||||
if args.noise_chan > 0:
|
||||
input = input[:, :-args.noise_chan] # remove noise channels
|
||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||
input = F.interpolate(input,
|
||||
scale_factor=model.scale_factor, mode='nearest')
|
||||
|
Loading…
Reference in New Issue
Block a user