Fix (possibly) multiple DistributedDataParallel reduce interference
parent 15384dc9bd
commit 1b1e0e82fa
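The training script wraps two networks in DistributedDataParallel: the main model and, for adversarial training, adv_model. Both previously shared the default process group, so their gradient all-reduces could (possibly) interfere. This commit gives each wrapper its own group via dist.new_group(), and moves the adversarial loss rescaling line so that it sits directly before the backward pass it scales.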
@@ -86,7 +86,8 @@ def gpu_worker(local_rank, args):
     model = getattr(models, args.model)
     model = model(in_channels, out_channels)
     model.to(args.device)
-    model = DistributedDataParallel(model, device_ids=[args.device])
+    model = DistributedDataParallel(model, device_ids=[args.device],
+            process_group=dist.new_group())
 
     criterion = getattr(torch.nn, args.criterion)
     criterion = criterion()
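For context, a minimal sketch of the pattern this hunk and the next adopt. It assumes dist.init_process_group has already run on every rank; the toy Linear models and the wrap_models helper are illustrative stand-ins, not code from this repository.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def wrap_models(device):
    # Assumption: dist.init_process_group(...) was already called on every rank.
    gen = torch.nn.Linear(8, 8).to(device)
    adv = torch.nn.Linear(8, 1).to(device)
    # A dedicated process group per model keeps each model's gradient
    # all-reduce on its own communicator; on the shared default group,
    # the two models' reductions could be interleaved across ranks.
    gen = DistributedDataParallel(gen, device_ids=[device],
                                  process_group=dist.new_group())
    adv = DistributedDataParallel(adv, device_ids=[device],
                                  process_group=dist.new_group())
    return gen, adv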
@@ -110,7 +111,8 @@ def gpu_worker(local_rank, args):
         adv_model = adv_model(in_channels + out_channels
                               if args.cgan else out_channels, 1)
         adv_model.to(args.device)
-        adv_model = DistributedDataParallel(adv_model, device_ids=[args.device])
+        adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
+                process_group=dist.new_group())
 
         adv_criterion = getattr(torch.nn, args.adv_criterion)
         adv_criterion = adv_criterion()
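One caveat on this pattern: dist.new_group() is itself a collective, so every rank must execute both calls, in the same order, for the model's and adv_model's groups to match up across processes.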
@@ -272,10 +274,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             epoch_loss[4] += adv_loss_real.item()
 
             adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
-            adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
             epoch_loss[2] += adv_loss.item()
 
             adv_optimizer.zero_grad()
+            adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
             adv_loss.backward()
             adv_optimizer.step()
 
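The line this hunk moves implements a gradient-scaling trick: for a scalar tensor L with value v, 0.001 * L + 0.999 * L.item() still has value v (0.001 v + 0.999 v = v), because .item() detaches the second term as a constant Python float, but backpropagating through it scales every gradient by 0.001. The reported epoch_loss[2] is therefore unchanged by the move; the rescaled tensor is now built directly before backward(). A small self-contained demonstration with made-up tensors, not code from the repository:

import torch

# The value is preserved, but the gradient is scaled by 0.001,
# because loss.item() contributes no gradient.
x = torch.tensor([2.0], requires_grad=True)
loss = (x ** 2).sum()                        # value 4.0, d(loss)/dx = 4.0

scaled = 0.001 * loss + 0.999 * loss.item()
print(scaled.item())                         # 4.0, same value as loss
scaled.backward()
print(x.grad)                                # tensor([0.0040]) = 0.001 * 4.0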