Add new pytorch features with fallback to old
parent b7b14a37eb
commit 7a0f0bd979
@@ -68,8 +68,10 @@ def gpu_worker(local_rank, node, args):
         world_size=args.world_size,
     )
     if not args.div_data:
-        #train_sampler = DistributedSampler(train_dataset, shuffle=True)
-        train_sampler = DistributedSampler(train_dataset)
+        try:
+            train_sampler = DistributedSampler(train_dataset, shuffle=True)
+        except TypeError:
+            train_sampler = DistributedSampler(train_dataset) # old pytorch
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
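The hunk above guards against older PyTorch releases whose DistributedSampler has no shuffle keyword: passing it there raises a TypeError, which the new code catches before falling back to the bare constructor. A minimal standalone sketch of the same fallback, with a hypothetical make_sampler helper and a toy TensorDataset standing in for train_dataset (not code from this repo):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

def make_sampler(dataset, shuffle, num_replicas, rank):
    # Sketch only: make_sampler is a hypothetical helper, not part of the repo.
    try:
        # newer PyTorch accepts the shuffle keyword
        return DistributedSampler(dataset, num_replicas=num_replicas,
                                  rank=rank, shuffle=shuffle)
    except TypeError:
        # older PyTorch has no shuffle keyword; construct without it
        return DistributedSampler(dataset, num_replicas=num_replicas, rank=rank)

dataset = TensorDataset(torch.arange(16).float())  # toy stand-in dataset
sampler = make_sampler(dataset, shuffle=True, num_replicas=2, rank=0)
print(list(iter(sampler)))  # indices assigned to rank 0

Passing num_replicas and rank explicitly lets the sketch run without initializing torch.distributed; the training script itself gets them from the process group.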
@@ -96,8 +98,10 @@ def gpu_worker(local_rank, node, args):
         world_size=args.world_size,
    )
     if not args.div_data:
-        #val_sampler = DistributedSampler(val_dataset, shuffle=False)
-        val_sampler = DistributedSampler(val_dataset)
+        try:
+            val_sampler = DistributedSampler(val_dataset, shuffle=False)
+        except TypeError:
+            val_sampler = DistributedSampler(val_dataset) # old pytorch
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
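In both hunks the sampler is then handed to DataLoader, so the loader's own shuffle must stay off and set_epoch should be called once per epoch so ranks reshuffle consistently. A self-contained sketch of that wiring; the batch size and epoch count are made-up values, not the repo's args:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())      # stand-in for train_dataset
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)

train_loader = DataLoader(
    dataset,
    batch_size=4,       # stand-in for args.batches
    shuffle=False,      # the sampler, not the loader, decides the order
    sampler=sampler,
    pin_memory=True,
)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reseed so each epoch sees a new permutation
    for (x,) in train_loader:
        pass  # forward/backward would go here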
@@ -236,7 +240,10 @@ def gpu_worker(local_rank, node, args):
             adv_scheduler.step(epoch_loss[0])

         if rank == 0:
-            logger.close()
+            try:
+                logger.flush()
+            except AttributeError:
+                logger.close() # old pytorch
 
             good = min_loss is None or epoch_loss[0] < min_loss[0]
             if good and epoch >= args.adv_start:
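The last hunk applies the same pattern to the logger: the per-epoch call now prefers flush(), and the AttributeError raised on the older PyTorch versions this code still supports falls back to close(), as the "# old pytorch" comment notes. A minimal sketch, assuming logger is a torch.utils.tensorboard.SummaryWriter; the log directory and scalar value are made up:

from torch.utils.tensorboard import SummaryWriter

logger = SummaryWriter('runs/example')            # hypothetical log directory
logger.add_scalar('loss/train', 0.42, global_step=0)

try:
    # newer PyTorch: write pending events to disk but keep the writer open
    logger.flush()
except AttributeError:
    # old pytorch: no flush(); closing the writer also writes everything out
    logger.close()

Flushing instead of closing lets the same writer keep appending across epochs rather than being torn down after the first one.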