Fix bug
This commit is contained in:
parent
cd63324724
commit
890a459363
@ -71,10 +71,10 @@ def add_train_args(parser):
|
|||||||
help='enable minimum reduction in adversarial criterion')
|
help='enable minimum reduction in adversarial criterion')
|
||||||
parser.add_argument('--cgan', action='store_true',
|
parser.add_argument('--cgan', action='store_true',
|
||||||
help='enable conditional GAN')
|
help='enable conditional GAN')
|
||||||
parser.add_argument('--loss-halflife', default=10, type=float,
|
|
||||||
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
|
||||||
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
parser.add_argument('--loss-fraction', default=0.5, type=float,
|
||||||
help='final fraction of loss (vs adv-loss)')
|
help='final fraction of loss (vs adv-loss)')
|
||||||
|
parser.add_argument('--loss-halflife', default=10, type=float,
|
||||||
|
help='half-life (epoch) to anneal loss while enhancing adv-loss')
|
||||||
|
|
||||||
parser.add_argument('--optimizer', default='Adam', type=str,
|
parser.add_argument('--optimizer', default='Adam', type=str,
|
||||||
help='optimizer from torch.optim')
|
help='optimizer from torch.optim')
|
||||||
|
@ -64,6 +64,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
train_dataset = FieldDataset(
|
train_dataset = FieldDataset(
|
||||||
in_patterns=args.train_in_patterns,
|
in_patterns=args.train_in_patterns,
|
||||||
tgt_patterns=args.train_tgt_patterns,
|
tgt_patterns=args.train_tgt_patterns,
|
||||||
|
rank=rank,
|
||||||
**vars(args),
|
**vars(args),
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
@ -83,6 +84,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
in_patterns=args.val_in_patterns,
|
in_patterns=args.val_in_patterns,
|
||||||
tgt_patterns=args.val_tgt_patterns,
|
tgt_patterns=args.val_tgt_patterns,
|
||||||
augment=False,
|
augment=False,
|
||||||
|
rank=rank,
|
||||||
**{k: v for k, v in vars(args).items() if k != 'augment'},
|
**{k: v for k, v in vars(args).items() if k != 'augment'},
|
||||||
)
|
)
|
||||||
if not args.div_data:
|
if not args.div_data:
|
||||||
@ -190,6 +192,7 @@ def gpu_worker(local_rank, node, args):
|
|||||||
|
|
||||||
torch.backends.cudnn.benchmark = True # NOTE: test perf
|
torch.backends.cudnn.benchmark = True # NOTE: test perf
|
||||||
|
|
||||||
|
logger = None
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
logger = SummaryWriter()
|
logger = SummaryWriter()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user