Add new pytorch features with fallback to old

This commit is contained in:
Yin Li 2020-03-07 14:49:18 -05:00
parent b7b14a37eb
commit 7a0f0bd979

View File

@ -68,8 +68,10 @@ def gpu_worker(local_rank, node, args):
world_size=args.world_size, world_size=args.world_size,
) )
if not args.div_data: if not args.div_data:
#train_sampler = DistributedSampler(train_dataset, shuffle=True) try:
train_sampler = DistributedSampler(train_dataset) train_sampler = DistributedSampler(train_dataset, shuffle=True)
except TypeError:
train_sampler = DistributedSampler(train_dataset) # old pytorch
train_loader = DataLoader( train_loader = DataLoader(
train_dataset, train_dataset,
batch_size=args.batches, batch_size=args.batches,
@ -96,8 +98,10 @@ def gpu_worker(local_rank, node, args):
world_size=args.world_size, world_size=args.world_size,
) )
if not args.div_data: if not args.div_data:
#val_sampler = DistributedSampler(val_dataset, shuffle=False) try:
val_sampler = DistributedSampler(val_dataset) val_sampler = DistributedSampler(val_dataset, shuffle=False)
except TypeError:
val_sampler = DistributedSampler(val_dataset) # old pytorch
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=args.batches, batch_size=args.batches,
@ -236,7 +240,10 @@ def gpu_worker(local_rank, node, args):
adv_scheduler.step(epoch_loss[0]) adv_scheduler.step(epoch_loss[0])
if rank == 0: if rank == 0:
logger.close() try:
logger.flush()
except AttributeError:
logger.close() # old pytorch
good = min_loss is None or epoch_loss[0] < min_loss[0] good = min_loss is None or epoch_loss[0] < min_loss[0]
if good and epoch >= args.adv_start: if good and epoch >= args.adv_start: