Fix upsampling crop bug
This commit is contained in:
parent
b609797e27
commit
ef0235f97b
@ -172,9 +172,10 @@ def crop(fields, start, crop, pad, scale_factor=1):
|
|||||||
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear')
|
||||||
x = x.numpy().squeeze(0)
|
x = x.numpy().squeeze(0)
|
||||||
|
|
||||||
# remove buffer
|
# remove buffer and excess padding
|
||||||
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
for d, (N, (p0, p1)) in enumerate(zip(crop, pad)):
|
||||||
begin, end = scale_factor, N + p0 + p1 - scale_factor
|
begin = scale_factor + (scale_factor - 1) * p0
|
||||||
|
end = scale_factor * (N + p0 + 1) + p1
|
||||||
x = x.take(range(begin, end), axis=1 + d)
|
x = x.take(range(begin, end), axis=1 + d)
|
||||||
|
|
||||||
new_fields.append(x)
|
new_fields.append(x)
|
||||||
|
@ -140,6 +140,7 @@ def gpu_worker(local_rank, args):
|
|||||||
if args.load_state:
|
if args.load_state:
|
||||||
state = torch.load(args.load_state, map_location=args.device)
|
state = torch.load(args.load_state, map_location=args.device)
|
||||||
args.start_epoch = state['epoch']
|
args.start_epoch = state['epoch']
|
||||||
|
args.adv_delay += args.start_epoch
|
||||||
model.module.load_state_dict(state['model'])
|
model.module.load_state_dict(state['model'])
|
||||||
optimizer.load_state_dict(state['optimizer'])
|
optimizer.load_state_dict(state['optimizer'])
|
||||||
scheduler.load_state_dict(state['scheduler'])
|
scheduler.load_state_dict(state['scheduler'])
|
||||||
@ -150,6 +151,8 @@ def gpu_worker(local_rank, args):
|
|||||||
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
torch.set_rng_state(state['rng'].cpu()) # move rng state back
|
||||||
if args.rank == 0:
|
if args.rank == 0:
|
||||||
min_loss = state['min_loss']
|
min_loss = state['min_loss']
|
||||||
|
if 'adv_model' not in state and args.adv:
|
||||||
|
min_loss = None # restarting with adversary wipes the record
|
||||||
print('checkpoint at epoch {} loaded from {}'.format(
|
print('checkpoint at epoch {} loaded from {}'.format(
|
||||||
state['epoch'], args.load_state))
|
state['epoch'], args.load_state))
|
||||||
del state
|
del state
|
||||||
|
Loading…
Reference in New Issue
Block a user