diff --git a/map2map/data/fields.py b/map2map/data/fields.py index 08cde15..10e05b6 100644 --- a/map2map/data/fields.py +++ b/map2map/data/fields.py @@ -172,9 +172,10 @@ def crop(fields, start, crop, pad, scale_factor=1): x = F.interpolate(x, scale_factor=scale_factor, mode='trilinear') x = x.numpy().squeeze(0) - # remove buffer + # remove buffer and excess padding for d, (N, (p0, p1)) in enumerate(zip(crop, pad)): - begin, end = scale_factor, N + p0 + p1 - scale_factor + begin = scale_factor + (scale_factor - 1) * p0 + end = scale_factor * (N + p0 + 1) + p1 x = x.take(range(begin, end), axis=1 + d) new_fields.append(x) diff --git a/map2map/train.py b/map2map/train.py index 2775911..eb98e7d 100644 --- a/map2map/train.py +++ b/map2map/train.py @@ -140,6 +140,7 @@ def gpu_worker(local_rank, args): if args.load_state: state = torch.load(args.load_state, map_location=args.device) args.start_epoch = state['epoch'] + args.adv_delay += args.start_epoch model.module.load_state_dict(state['model']) optimizer.load_state_dict(state['optimizer']) scheduler.load_state_dict(state['scheduler']) @@ -150,6 +151,8 @@ def gpu_worker(local_rank, args): torch.set_rng_state(state['rng'].cpu()) # move rng state back if args.rank == 0: min_loss = state['min_loss'] + if 'adv_model' not in state and args.adv: + min_loss = None # restarting with adversary wipes the record print('checkpoint at epoch {} loaded from {}'.format( state['epoch'], args.load_state)) del state