Fix unstable training by limiting PyTorch version to 1.1
This commit is contained in:
parent
437126e296
commit
11c9caa1e2
7 changed files with 18 additions and 27 deletions
|
@@ -48,7 +48,7 @@ def test(args):
|
|||
|
||||
loss = criterion(output, target)
|
||||
|
||||
print('sample {} loss: {}'.format(i, loss))
|
||||
print('sample {} loss: {}'.format(i, loss.item()))
|
||||
|
||||
if args.norms is not None:
|
||||
norm = test_dataset.norms[0] # FIXME
|
||||
|
|
|
@@ -48,7 +48,8 @@ def gpu_worker(local_rank, args):
|
|||
norms=args.norms,
|
||||
pad_or_crop=args.pad_or_crop,
|
||||
)
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batches,
|
||||
|
@@ -65,7 +66,8 @@ def gpu_worker(local_rank, args):
|
|||
norms=args.norms,
|
||||
pad_or_crop=args.pad_or_crop,
|
||||
)
|
||||
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistributedSampler(val_dataset)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
|
@@ -112,9 +114,9 @@ def gpu_worker(local_rank, args):
|
|||
|
||||
if args.rank == 0:
|
||||
args.logger = SummaryWriter()
|
||||
hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
|
||||
else str(v) for k, v in vars(args).items()}
|
||||
args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
|
||||
#hparam = {k: v if isinstance(v, (int, float, str, bool, torch.Tensor))
|
||||
# else str(v) for k, v in vars(args).items()}
|
||||
#args.logger.add_hparams(hparam_dict=hparam, metric_dict={})
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
@@ -125,7 +127,8 @@ def gpu_worker(local_rank, args):
|
|||
scheduler.step(val_loss)
|
||||
|
||||
if args.rank == 0:
|
||||
args.logger.close()
|
||||
print(end='', flush=True)
|
||||
args.logger.flush()
|
||||
|
||||
state = {
|
||||
'epoch': epoch + 1,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue