This commit is contained in:
Yin Li 2020-02-06 19:34:43 -06:00
parent cd63324724
commit 890a459363
2 changed files with 5 additions and 2 deletions

View File

@ -71,10 +71,10 @@ def add_train_args(parser):
help='enable minimum reduction in adversarial criterion') help='enable minimum reduction in adversarial criterion')
parser.add_argument('--cgan', action='store_true', parser.add_argument('--cgan', action='store_true',
help='enable conditional GAN') help='enable conditional GAN')
parser.add_argument('--loss-halflife', default=10, type=float,
help='half-life (epoch) to anneal loss while enhancing adv-loss')
parser.add_argument('--loss-fraction', default=0.5, type=float, parser.add_argument('--loss-fraction', default=0.5, type=float,
help='final fraction of loss (vs adv-loss)') help='final fraction of loss (vs adv-loss)')
parser.add_argument('--loss-halflife', default=10, type=float,
help='half-life (epoch) to anneal loss while enhancing adv-loss')
parser.add_argument('--optimizer', default='Adam', type=str, parser.add_argument('--optimizer', default='Adam', type=str,
help='optimizer from torch.optim') help='optimizer from torch.optim')

View File

@ -64,6 +64,7 @@ def gpu_worker(local_rank, node, args):
train_dataset = FieldDataset( train_dataset = FieldDataset(
in_patterns=args.train_in_patterns, in_patterns=args.train_in_patterns,
tgt_patterns=args.train_tgt_patterns, tgt_patterns=args.train_tgt_patterns,
rank=rank,
**vars(args), **vars(args),
) )
if not args.div_data: if not args.div_data:
@ -83,6 +84,7 @@ def gpu_worker(local_rank, node, args):
in_patterns=args.val_in_patterns, in_patterns=args.val_in_patterns,
tgt_patterns=args.val_tgt_patterns, tgt_patterns=args.val_tgt_patterns,
augment=False, augment=False,
rank=rank,
**{k: v for k, v in vars(args).items() if k != 'augment'}, **{k: v for k, v in vars(args).items() if k != 'augment'},
) )
if not args.div_data: if not args.div_data:
@ -190,6 +192,7 @@ def gpu_worker(local_rank, node, args):
torch.backends.cudnn.benchmark = True # NOTE: test perf torch.backends.cudnn.benchmark = True # NOTE: test perf
logger = None
if rank == 0: if rank == 0:
logger = SummaryWriter() logger = SummaryWriter()