Fix (possibly) multiple DistributedDataParallel reduce interference

This commit is contained in:
Yin Li 2020-01-10 14:39:16 -05:00
parent 15384dc9bd
commit 1b1e0e82fa

View file

@ -86,7 +86,8 @@ def gpu_worker(local_rank, args):
model = getattr(models, args.model)
model = model(in_channels, out_channels)
model.to(args.device)
model = DistributedDataParallel(model, device_ids=[args.device])
model = DistributedDataParallel(model, device_ids=[args.device],
process_group=dist.new_group())
criterion = getattr(torch.nn, args.criterion)
criterion = criterion()
@ -110,7 +111,8 @@ def gpu_worker(local_rank, args):
adv_model = adv_model(in_channels + out_channels
if args.cgan else out_channels, 1)
adv_model.to(args.device)
adv_model = DistributedDataParallel(adv_model, device_ids=[args.device])
adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
process_group=dist.new_group())
adv_criterion = getattr(torch.nn, args.adv_criterion)
adv_criterion = adv_criterion()
@ -272,10 +274,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
epoch_loss[4] += adv_loss_real.item()
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
epoch_loss[2] += adv_loss.item()
adv_optimizer.zero_grad()
adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
adv_loss.backward()
adv_optimizer.step()