Fix adversary model optim scheduler

Yin Li 2020-01-19 19:10:14 -05:00
parent 1b1e0e82fa
commit e7d2435a96
3 changed files with 13 additions and 13 deletions

View File

@@ -12,7 +12,7 @@ class FieldDataset(Dataset):
     `in_patterns` is a list of glob patterns for the input fields.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Likewise `tgt_patterns` is for target fields.
-    Input and target samples of all fields are matched by sorting the globbed files.
+    Input and target samples are matched by sorting the globbed files.
     `norms` can be a list of callables to normalize each field.
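For reference, the pairing described in this docstring amounts to sorting each glob independently and zipping the results by index. A minimal sketch of that matching, using a hypothetical helper rather than the actual FieldDataset code:

# Hypothetical sketch of the sorted-glob pairing described in the docstring;
# not the actual FieldDataset implementation.
from glob import glob

def match_samples(in_patterns, tgt_patterns):
    in_files = [sorted(glob(p)) for p in in_patterns]    # one sorted list per input field
    tgt_files = [sorted(glob(p)) for p in tgt_patterns]  # one sorted list per target field
    # the i-th sample pairs the i-th file of every input field
    # with the i-th file of every target field
    return list(zip(*in_files)), list(zip(*tgt_files))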

View File

@@ -32,12 +32,13 @@ def test(args):
         device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)

-    from collections import OrderedDict
-    model_state = OrderedDict()
-    for k, v in state['model'].items():
-        model_k = k.replace('module.', '', 1)  # FIXME
-        model_state[model_k] = v
-    model.load_state_dict(model_state)
+    # from collections import OrderedDict
+    # model_state = OrderedDict()
+    # for k, v in state['model'].items():
+    #     model_k = k.replace('module.', '', 1)  # FIXME
+    #     model_state[model_k] = v
+    # model.load_state_dict(model_state)
+    model.load_state_dict(state['model'])
     print('model state at epoch {} loaded from {}'.format(
         state['epoch'], args.load_state))
     del state
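The block commented out here stripped the 'module.' prefix that DataParallel / DistributedDataParallel wrappers add to parameter names; the replacement assumes the checkpoint was saved without that prefix. A minimal sketch of a loader that tolerates both kinds of checkpoints, shown only as a hypothetical fallback and not as part of this commit:

# Hypothetical fallback loader: accept checkpoints saved with or without the
# 'module.' prefix added by (Distributed)DataParallel.
def load_state_flexible(model, model_state):
    try:
        model.load_state_dict(model_state)
    except RuntimeError:
        stripped = {k.replace('module.', '', 1): v for k, v in model_state.items()}
        model.load_state_dict(stripped)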

View File

@@ -130,8 +130,7 @@ def gpu_worker(local_rank, args):
             betas=(0.5, 0.999),
             weight_decay=args.adv_weight_decay,
         )
-        adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
-            factor=0.5, patience=3, verbose=True)
+        adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1)

     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
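With this change the adversary's learning rate no longer tracks a validation metric; StepLR simply multiplies it by gamma once every step_size epochs. A minimal standalone sketch of the resulting decay, assuming a placeholder parameter and an initial learning rate of 1e-3 (the real value comes from the training arguments):

import torch

# Placeholder parameter and optimizer; only the schedule itself is of interest here.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3, betas=(0.5, 0.999))
sched = torch.optim.lr_scheduler.StepLR(opt, 30, gamma=0.1)

for epoch in range(90):
    opt.step()
    sched.step()
# learning rate: 1e-3 for epochs 0-29, 1e-4 for epochs 30-59, 1e-5 for epochs 60-89

Unlike ReduceLROnPlateau, StepLR.step() takes no metric argument, which matches the plain per-epoch call added in the next hunk.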
@@ -185,6 +184,8 @@ def gpu_worker(local_rank, args):
         epoch_loss = val_loss
         scheduler.step(epoch_loss[0])
+        if args.adv:
+            adv_scheduler.step()

         if args.rank == 0:
             print(end='', flush=True)
@@ -274,7 +275,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             epoch_loss[4] += adv_loss_real.item()
             adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
-            adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
             epoch_loss[2] += adv_loss.item()

             adv_optimizer.zero_grad()
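The deleted line was a gradient-scaling trick: adv_loss.item() is a plain Python number detached from the graph, so 0.001 * adv_loss + 0.999 * adv_loss.item() keeps the logged value of adv_loss unchanged while shrinking its gradient by a factor of 1000. A minimal standalone sketch of the trick with a toy tensor, not the training code:

import torch

x = torch.tensor(2.0, requires_grad=True)
raw = x ** 2                                  # value 4.0, d(raw)/dx = 4.0
scaled = 0.001 * raw + 0.999 * raw.item()     # same value 4.0, gradient scaled by 0.001
scaled.backward()
print(scaled.item(), x.grad.item())           # 4.0 0.004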
@@ -287,9 +287,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             loss_adv = adv_criterion(eval_out, real)  # FIXME try min
             epoch_loss[1] += loss_adv.item()

-            # loss_fac = loss.item() / (loss.item() + loss_adv.item())
-            # loss = 0.5 * (loss * (1 + loss_fac) + loss_adv * loss_fac)  # FIXME does this work?
-            loss += 0.001 * (loss_adv - loss_adv.item())
+            loss_fac = loss.item() / (loss_adv.item() + 1e-8)
+            loss += loss_fac * (loss_adv - loss_adv.item())  # FIXME does this work?

             optimizer.zero_grad()
             loss.backward()
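The new weighting uses the same detach-by-.item() idea but with an adaptive factor: loss_adv - loss_adv.item() is zero in value yet carries the full gradient of loss_adv, so the logged generator loss stays equal to loss while the backward pass adds the adversarial gradient scaled by loss_fac = loss.item() / (loss_adv.item() + 1e-8). A minimal sketch with toy values (hypothetical tensors, not the training code):

import torch

w = torch.tensor(1.0, requires_grad=True)
loss = (w - 3.0) ** 2              # main loss: value 4.0, gradient -4.0
loss_adv = 0.5 * w ** 2            # adversarial loss: value 0.5, gradient 1.0
loss_fac = loss.item() / (loss_adv.item() + 1e-8)      # roughly 8.0
loss = loss + loss_fac * (loss_adv - loss_adv.item())  # value still 4.0
loss.backward()
print(loss.item(), w.grad.item())  # 4.0 and about 4.0 (-4.0 + 8.0 * 1.0)

The 1e-8 term guards against division by zero when the adversarial loss is tiny, at the cost of making the effective weight very large in that regime.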