Fix possible reduce interference between multiple DistributedDataParallel models
This commit is contained in:
parent
15384dc9bd
commit
1b1e0e82fa
@ -86,7 +86,8 @@ def gpu_worker(local_rank, args):
|
|||||||
model = getattr(models, args.model)
|
model = getattr(models, args.model)
|
||||||
model = model(in_channels, out_channels)
|
model = model(in_channels, out_channels)
|
||||||
model.to(args.device)
|
model.to(args.device)
|
||||||
model = DistributedDataParallel(model, device_ids=[args.device])
|
model = DistributedDataParallel(model, device_ids=[args.device],
|
||||||
|
process_group=dist.new_group())
|
||||||
|
|
||||||
criterion = getattr(torch.nn, args.criterion)
|
criterion = getattr(torch.nn, args.criterion)
|
||||||
criterion = criterion()
|
criterion = criterion()
|
||||||
@ -110,7 +111,8 @@ def gpu_worker(local_rank, args):
|
|||||||
adv_model = adv_model(in_channels + out_channels
|
adv_model = adv_model(in_channels + out_channels
|
||||||
if args.cgan else out_channels, 1)
|
if args.cgan else out_channels, 1)
|
||||||
adv_model.to(args.device)
|
adv_model.to(args.device)
|
||||||
adv_model = DistributedDataParallel(adv_model, device_ids=[args.device])
|
adv_model = DistributedDataParallel(adv_model, device_ids=[args.device],
|
||||||
|
process_group=dist.new_group())
|
||||||
|
|
||||||
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
adv_criterion = getattr(torch.nn, args.adv_criterion)
|
||||||
adv_criterion = adv_criterion()
|
adv_criterion = adv_criterion()
|
||||||
@ -272,10 +274,10 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
|
|||||||
epoch_loss[4] += adv_loss_real.item()
|
epoch_loss[4] += adv_loss_real.item()
|
||||||
|
|
||||||
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
|
||||||
|
adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
|
||||||
epoch_loss[2] += adv_loss.item()
|
epoch_loss[2] += adv_loss.item()
|
||||||
|
|
||||||
adv_optimizer.zero_grad()
|
adv_optimizer.zero_grad()
|
||||||
adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
|
|
||||||
adv_loss.backward()
|
adv_loss.backward()
|
||||||
adv_optimizer.step()
|
adv_optimizer.step()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user