Fix DistributedDataParallel model save and load during training, leave testing for later
commit 437126e296
parent f2e9af6d5f
@@ -94,7 +94,7 @@ def gpu_worker(local_rank, args):
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
         args.start_epoch = state['epoch']
-        model.load_state_dict(state['model'])
+        model.module.load_state_dict(state['model'])
         optimizer.load_state_dict(state['optimizer'])
         scheduler.load_state_dict(state['scheduler'])
         torch.set_rng_state(state['rng'].cpu())  # move rng state back
@@ -129,7 +129,7 @@ def gpu_worker(local_rank, args):

         state = {
             'epoch': epoch + 1,
-            'model': model.state_dict(),
+            'model': model.module.state_dict(),
             'optimizer' : optimizer.state_dict(),
             'scheduler' : scheduler.state_dict(),
             'rng' : torch.get_rng_state(),
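Both hunks fix the same mismatch: DistributedDataParallel wraps the network and exposes the original model as its .module attribute, so state_dict() called on the wrapper returns keys prefixed with "module.". A checkpoint saved from the wrapper therefore cannot be loaded into an unwrapped model, and vice versa. Saving and loading through model.module keeps the checkpoint keys prefix-free. The following is a minimal sketch, not part of the commit, assuming a single CPU process with the gloo backend; the nn.Linear stand-in and the checkpoint.pt path are hypothetical placeholders:

    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # Single-process setup so the sketch runs standalone; real training
    # launches one process per GPU.
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    model = nn.Linear(10, 2)  # stand-in for the real network
    ddp = nn.parallel.DistributedDataParallel(model)

    # DDP stores the wrapped network under .module, so keys on the wrapper
    # gain a 'module.' prefix unless the inner module is addressed directly.
    print(list(ddp.state_dict()))         # ['module.weight', 'module.bias']
    print(list(ddp.module.state_dict()))  # ['weight', 'bias']

    # Saving the inner module's state_dict keeps the checkpoint loadable
    # by a plain, unwrapped model as well as by a DDP-wrapped one.
    torch.save({'model': ddp.module.state_dict()}, 'checkpoint.pt')
    state = torch.load('checkpoint.pt')
    model.load_state_dict(state['model'])        # plain model: OK
    ddp.module.load_state_dict(state['model'])   # DDP model: OK

    dist.destroy_process_group()

An alternative is to save the wrapper's state_dict and strip the "module." prefix from every key at load time, but going through .module on both the save and load paths, as this commit does, avoids that key rewriting entirely.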