Remove noise channels
This commit is contained in:
parent
3e13202fb5
commit
f442dd59ba
@ -37,9 +37,6 @@ def add_common_args(parser):
|
|||||||
parser.add_argument('--scale-factor', default=1, type=int,
|
parser.add_argument('--scale-factor', default=1, type=int,
|
||||||
help='input upsampling factor for super-resolution purpose, in '
|
help='input upsampling factor for super-resolution purpose, in '
|
||||||
'which case crop and pad will be taken at the original resolution')
|
'which case crop and pad will be taken at the original resolution')
|
||||||
parser.add_argument('--noise-chan', default=0, type=int,
|
|
||||||
help='input noise channels to produce the output stochasticity, '
|
|
||||||
'if the input does not completely determines the output')
|
|
||||||
|
|
||||||
parser.add_argument('--model', required=True, type=str,
|
parser.add_argument('--model', required=True, type=str,
|
||||||
help='model from .models')
|
help='model from .models')
|
||||||
|
@ -27,8 +27,6 @@ class FieldDataset(Dataset):
|
|||||||
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
the input for super-resolution, in which case `crop` and `pad` are sizes of
|
||||||
the input resolution.
|
the input resolution.
|
||||||
|
|
||||||
Noise channels can be concatenated to the input.
|
|
||||||
|
|
||||||
`cache` enables data caching.
|
`cache` enables data caching.
|
||||||
`div_data` enables data division, useful when combined with caching.
|
`div_data` enables data division, useful when combined with caching.
|
||||||
"""
|
"""
|
||||||
@ -142,11 +140,6 @@ class FieldDataset(Dataset):
|
|||||||
for norm, x in zip(self.tgt_norms, tgt_fields):
|
for norm, x in zip(self.tgt_norms, tgt_fields):
|
||||||
norm(x)
|
norm(x)
|
||||||
|
|
||||||
if self.noise_chan > 0:
|
|
||||||
in_fields.append(
|
|
||||||
torch.randn((self.noise_chan,) + in_fields[0].shape[1:],
|
|
||||||
dtype=torch.float32))
|
|
||||||
|
|
||||||
in_fields = torch.cat(in_fields, dim=0)
|
in_fields = torch.cat(in_fields, dim=0)
|
||||||
tgt_fields = torch.cat(tgt_fields, dim=0)
|
tgt_fields = torch.cat(tgt_fields, dim=0)
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ def test(args):
|
|||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
pad=args.pad,
|
pad=args.pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
noise_chan=args.noise_chan,
|
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
)
|
)
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
@ -35,7 +34,7 @@ def test(args):
|
|||||||
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
in_chan, out_chan = test_dataset.in_chan, test_dataset.tgt_chan
|
||||||
|
|
||||||
model = getattr(models, args.model)
|
model = getattr(models, args.model)
|
||||||
model = model(sum(in_chan) + args.noise_chan, sum(out_chan))
|
model = model(sum(in_chan), sum(out_chan))
|
||||||
criterion = getattr(torch.nn, args.criterion)
|
criterion = getattr(torch.nn, args.criterion)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
|
|
||||||
|
@ -66,7 +66,6 @@ def gpu_worker(local_rank, node, args):
|
|||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
pad=args.pad,
|
pad=args.pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
noise_chan=args.noise_chan,
|
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
div_data=args.div_data,
|
div_data=args.div_data,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
@ -96,7 +95,6 @@ def gpu_worker(local_rank, node, args):
|
|||||||
crop=args.crop,
|
crop=args.crop,
|
||||||
pad=args.pad,
|
pad=args.pad,
|
||||||
scale_factor=args.scale_factor,
|
scale_factor=args.scale_factor,
|
||||||
noise_chan=args.noise_chan,
|
|
||||||
cache=args.cache,
|
cache=args.cache,
|
||||||
div_data=args.div_data,
|
div_data=args.div_data,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
@ -119,7 +117,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
args.in_chan, args.out_chan = train_dataset.in_chan, train_dataset.tgt_chan
|
||||||
|
|
||||||
model = getattr(models, args.model)
|
model = getattr(models, args.model)
|
||||||
model = model(sum(args.in_chan) + args.noise_chan, sum(args.out_chan))
|
model = model(sum(args.in_chan), sum(args.out_chan))
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model = DistributedDataParallel(model, device_ids=[device],
|
model = DistributedDataParallel(model, device_ids=[device],
|
||||||
process_group=dist.new_group())
|
process_group=dist.new_group())
|
||||||
@ -306,8 +304,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
||||||
target = narrow_like(target, output) # FIXME pad
|
target = narrow_like(target, output) # FIXME pad
|
||||||
if args.noise_chan > 0:
|
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
@ -450,8 +446,6 @@ def validate(epoch, loader, model, criterion, adv_model, adv_criterion,
|
|||||||
output = model(input)
|
output = model(input)
|
||||||
|
|
||||||
target = narrow_like(target, output) # FIXME pad
|
target = narrow_like(target, output) # FIXME pad
|
||||||
if args.noise_chan > 0:
|
|
||||||
input = input[:, :-args.noise_chan] # remove noise channels
|
|
||||||
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
if hasattr(model, 'scale_factor') and model.scale_factor != 1:
|
||||||
input = F.interpolate(input,
|
input = F.interpolate(input,
|
||||||
scale_factor=model.scale_factor, mode='nearest')
|
scale_factor=model.scale_factor, mode='nearest')
|
||||||
|
Loading…
Reference in New Issue
Block a user