Fix bug
This commit is contained in:
parent
cd63324724
commit
890a459363
@ -71,10 +71,10 @@ def add_train_args(parser):
|
||||
help='enable minimum reduction in adversarial criterion')
|
||||
parser.add_argument('--cgan', action='store_true',
|
||||
help='enable conditional GAN')
|
||||
parser.add_argument('--loss-halflife', default=10, type=float,
|
||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||
help='final fraction of loss (vs adv-loss)')
|
||||
parser.add_argument('--loss-halflife', default=10, type=float,
|
||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||
|
||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||
help='optimizer from torch.optim')
|
||||
|
@ -64,6 +64,7 @@ def gpu_worker(local_rank, node, args):
|
||||
train_dataset = FieldDataset(
|
||||
in_patterns=args.train_in_patterns,
|
||||
tgt_patterns=args.train_tgt_patterns,
|
||||
rank=rank,
|
||||
**vars(args),
|
||||
)
|
||||
if not args.div_data:
|
||||
@ -83,6 +84,7 @@ def gpu_worker(local_rank, node, args):
|
||||
in_patterns=args.val_in_patterns,
|
||||
tgt_patterns=args.val_tgt_patterns,
|
||||
augment=False,
|
||||
rank=rank,
|
||||
**{k: v for k, v in vars(args).items() if k != 'augment'},
|
||||
)
|
||||
if not args.div_data:
|
||||
@ -190,6 +192,7 @@ def gpu_worker(local_rank, node, args):
|
||||
|
||||
torch.backends.cudnn.benchmark = True # NOTE: test perf
|
||||
|
||||
logger = None
|
||||
if rank == 0:
|
||||
logger = SummaryWriter()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user