Add new pytorch features with fallback to old
This commit is contained in:
parent
b7b14a37eb
commit
7a0f0bd979
1 changed files with 12 additions and 5 deletions
|
@ -68,8 +68,10 @@ def gpu_worker(local_rank, node, args):
|
|||
world_size=args.world_size,
|
||||
)
|
||||
if not args.div_data:
|
||||
#train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
try:
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=True)
|
||||
except TypeError:
|
||||
train_sampler = DistributedSampler(train_dataset) # old pytorch
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batches,
|
||||
|
@ -96,8 +98,10 @@ def gpu_worker(local_rank, node, args):
|
|||
world_size=args.world_size,
|
||||
)
|
||||
if not args.div_data:
|
||||
#val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
val_sampler = DistributedSampler(val_dataset)
|
||||
try:
|
||||
val_sampler = DistributedSampler(val_dataset, shuffle=False)
|
||||
except TypeError:
|
||||
val_sampler = DistributedSampler(val_dataset) # old pytorch
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=args.batches,
|
||||
|
@ -236,7 +240,10 @@ def gpu_worker(local_rank, node, args):
|
|||
adv_scheduler.step(epoch_loss[0])
|
||||
|
||||
if rank == 0:
|
||||
logger.close()
|
||||
try:
|
||||
logger.flush()
|
||||
except AttributeError:
|
||||
logger.close() # old pytorch
|
||||
|
||||
good = min_loss is None or epoch_loss[0] < min_loss[0]
|
||||
if good and epoch >= args.adv_start:
|
||||
|
|
Loading…
Reference in a new issue