From 7a0f0bd979e84da21ab9f045f4c4159e88acfa56 Mon Sep 17 00:00:00 2001
From: Yin Li
Date: Sat, 7 Mar 2020 14:49:18 -0500
Subject: [PATCH] Add new pytorch features with fallback to old

---
 map2map/train.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/map2map/train.py b/map2map/train.py
index ce61f79..177cebf 100644
--- a/map2map/train.py
+++ b/map2map/train.py
@@ -68,8 +68,10 @@ def gpu_worker(local_rank, node, args):
         world_size=args.world_size,
     )
     if not args.div_data:
-        #train_sampler = DistributedSampler(train_dataset, shuffle=True)
-        train_sampler = DistributedSampler(train_dataset)
+        try:
+            train_sampler = DistributedSampler(train_dataset, shuffle=True)
+        except TypeError:
+            train_sampler = DistributedSampler(train_dataset) # old pytorch
     train_loader = DataLoader(
         train_dataset,
         batch_size=args.batches,
@@ -96,8 +98,10 @@ def gpu_worker(local_rank, node, args):
         world_size=args.world_size,
     )
     if not args.div_data:
-        #val_sampler = DistributedSampler(val_dataset, shuffle=False)
-        val_sampler = DistributedSampler(val_dataset)
+        try:
+            val_sampler = DistributedSampler(val_dataset, shuffle=False)
+        except TypeError:
+            val_sampler = DistributedSampler(val_dataset) # old pytorch
     val_loader = DataLoader(
         val_dataset,
         batch_size=args.batches,
@@ -236,7 +240,10 @@ def gpu_worker(local_rank, node, args):
             adv_scheduler.step(epoch_loss[0])

         if rank == 0:
-            logger.close()
+            try:
+                logger.flush()
+            except AttributeError:
+                logger.close() # old pytorch

         good = min_loss is None or epoch_loss[0] < min_loss[0]
         if good and epoch >= args.adv_start:
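
The fallbacks in this patch follow Python's EAFP idiom: call the newer API and catch
the exception an older PyTorch raises (TypeError for an unknown keyword argument,
AttributeError for a missing method), rather than parsing torch.__version__. Below is
a minimal standalone sketch of the same pattern, not code from this repository: it
passes num_replicas=1 and rank=0 explicitly so DistributedSampler can be constructed
without initializing a process group, and it assumes logger is a
torch.utils.tensorboard.SummaryWriter, as the flush/close calls suggest.

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler
    from torch.utils.tensorboard import SummaryWriter

    dataset = TensorDataset(torch.arange(8))

    # Probe the newer signature; an older DistributedSampler raises
    # TypeError on the unknown `shuffle` keyword.
    try:
        sampler = DistributedSampler(dataset, num_replicas=1, rank=0,
                                     shuffle=True)
    except TypeError:
        sampler = DistributedSampler(dataset, num_replicas=1, rank=0)  # old pytorch

    print(list(sampler))  # a permutation (or natural order) of indices 0..7

    logger = SummaryWriter()  # stand-in for the logger in train.py (assumption)
    logger.add_scalar('loss', 0.5, 0)

    # Prefer flush() so the writer keeps logging across epochs; fall back
    # to close() on versions whose SummaryWriter predates flush().
    try:
        logger.flush()
    except AttributeError:
        logger.close()  # old pytorch

Compared with gating on the version string, the try/except form keeps working as
signatures evolve and costs nothing when the newer API is available.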