Add new pytorch features with fallback to old

This commit is contained in:
Yin Li 2020-03-07 14:49:18 -05:00
parent b7b14a37eb
commit 7a0f0bd979

View File

@ -68,8 +68,10 @@ def gpu_worker(local_rank, node, args):
world_size=args.world_size,
)
if not args.div_data:
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_sampler = DistributedSampler(train_dataset)
try:
train_sampler = DistributedSampler(train_dataset, shuffle=True)
except TypeError:
train_sampler = DistributedSampler(train_dataset) # old pytorch
train_loader = DataLoader(
train_dataset,
batch_size=args.batches,
@ -96,8 +98,10 @@ def gpu_worker(local_rank, node, args):
world_size=args.world_size,
)
if not args.div_data:
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
val_sampler = DistributedSampler(val_dataset)
try:
val_sampler = DistributedSampler(val_dataset, shuffle=False)
except TypeError:
val_sampler = DistributedSampler(val_dataset) # old pytorch
val_loader = DataLoader(
val_dataset,
batch_size=args.batches,
@ -236,7 +240,10 @@ def gpu_worker(local_rank, node, args):
adv_scheduler.step(epoch_loss[0])
if rank == 0:
logger.close()
try:
logger.flush()
except AttributeError:
logger.close() # old pytorch
good = min_loss is None or epoch_loss[0] < min_loss[0]
if good and epoch >= args.adv_start: