Fix adversary model optim scheduler

parent 1b1e0e82fa
commit e7d2435a96
@@ -12,7 +12,7 @@ class FieldDataset(Dataset):
     `in_patterns` is a list of glob patterns for the input fields.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Likewise `tgt_patterns` is for target fields.
-    Input and target samples of all fields are matched by sorting the globbed files.
+    Input and target samples are matched by sorting the globbed files.

     `norms` can be a list of callables to normalize each field.
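The wording fix above doesn't change behavior: samples are still paired by position after sorting each field's glob results. A minimal sketch of that matching, using the docstring's example patterns (paths are illustrative, not from this repo):

```python
from glob import glob

# The docstring's example: one glob pattern per input field.
in_patterns = ['/train/field1_*.npy', '/train/field2_*.npy']

# Each field's files are globbed and sorted; sample i of every field
# is then taken from position i of its sorted list.
in_files = [sorted(glob(p)) for p in in_patterns]
samples = list(zip(*in_files))  # samples[i] pairs field1 and field2 files
```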
@@ -32,12 +32,13 @@ def test(args):

     device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)
-    from collections import OrderedDict
-    model_state = OrderedDict()
-    for k, v in state['model'].items():
-        model_k = k.replace('module.', '', 1) # FIXME
-        model_state[model_k] = v
-    model.load_state_dict(model_state)
+    # from collections import OrderedDict
+    # model_state = OrderedDict()
+    # for k, v in state['model'].items():
+    #     model_k = k.replace('module.', '', 1) # FIXME
+    #     model_state[model_k] = v
+    # model.load_state_dict(model_state)
+    model.load_state_dict(state['model'])
     print('model state at epoch {} loaded from {}'.format(
         state['epoch'], args.load_state))
     del state

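For context on the block being commented out: DataParallel and DistributedDataParallel save parameter keys with a 'module.' prefix, which a bare model rejects, and the loop stripped that prefix. Loading state['model'] directly, as the new line does, assumes the checkpoint was saved from the unwrapped module. A toy reproduction of the prefix round-trip (the model here is illustrative, not from this repo):

```python
import torch
from collections import OrderedDict

model = torch.nn.Linear(4, 2)

# Simulate keys as saved from a DataParallel/DistributedDataParallel wrapper.
wrapped_state = OrderedDict(
    ('module.' + k, v) for k, v in model.state_dict().items())

# Stripping the first 'module.' restores keys the bare model accepts,
# which is what the now-commented loop did.
stripped = OrderedDict(
    (k.replace('module.', '', 1), v) for k, v in wrapped_state.items())
model.load_state_dict(stripped)
```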
@@ -130,8 +130,7 @@ def gpu_worker(local_rank, args):
         betas=(0.5, 0.999),
         weight_decay=args.adv_weight_decay,
     )
-    adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
-            factor=0.5, patience=3, verbose=True)
+    adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1)

     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
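This is the scheduler change the commit title refers to: ReduceLROnPlateau cut the adversary's learning rate only when a monitored metric stalled, while StepLR(adv_optimizer, 30, gamma=0.1) cuts it by 10x every 30 epochs unconditionally. A self-contained sketch of the decay, with a toy optimizer and placeholder hyperparameters:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-3, betas=(0.5, 0.999))
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)

for epoch in range(90):
    opt.step()    # real training steps would go here
    sched.step()  # multiplies lr by 0.1 after epochs 30 and 60

print(opt.param_groups[0]['lr'])  # ~1e-06 after three decays
```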
@@ -185,6 +184,8 @@ def gpu_worker(local_rank, args):
         epoch_loss = val_loss

         scheduler.step(epoch_loss[0])
+        if args.adv:
+            adv_scheduler.step()

         if args.rank == 0:
             print(end='', flush=True)
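Note the asymmetry this hunk leaves in place: the generator's ReduceLROnPlateau still needs its monitored metric, hence scheduler.step(epoch_loss[0]), while the new StepLR tracks its own epoch count, hence the bare adv_scheduler.step(). A minimal sketch of the two call shapes (the optimizers and the metric value are placeholders):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params)
adv_optimizer = torch.optim.Adam(params)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1)

val_loss = 0.42            # placeholder for the epoch's validation loss
scheduler.step(val_loss)   # plateau schedulers need the monitored metric
adv_scheduler.step()       # StepLR advances on its own epoch counter
```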
@@ -274,7 +275,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             epoch_loss[4] += adv_loss_real.item()

             adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
-            adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
             epoch_loss[2] += adv_loss.item()

             adv_optimizer.zero_grad()
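The deleted line was a gradient damper: adv_loss.item() is a detached Python float, so the mixed expression has the same value as adv_loss but only 0.001 of its gradient reaches the discriminator; removing it restores the full gradient. A toy verification of that value/gradient split (a scalar stands in for the real loss):

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
adv_loss = x ** 2  # value 4.0, d(adv_loss)/dx = 4.0

# The line this commit deletes: .item() detaches, so the value is
# unchanged but the gradient is scaled by the 0.001 factor.
damped = 0.001 * adv_loss + 0.999 * adv_loss.item()
damped.backward()

print(damped.item())  # 4.0, same reported loss
print(x.grad)         # tensor(0.0040), i.e. 0.001 * 4.0
```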
@@ -287,9 +287,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             loss_adv = adv_criterion(eval_out, real) # FIXME try min
             epoch_loss[1] += loss_adv.item()

-            # loss_fac = loss.item() / (loss.item() + loss_adv.item())
-            # loss = 0.5 * (loss * (1 + loss_fac) + loss_adv * loss_fac) # FIXME does this work?
-            loss += 0.001 * (loss_adv - loss_adv.item())
+            loss_fac = loss.item() / (loss_adv.item() + 1e-8)
+            loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?

         optimizer.zero_grad()
         loss.backward()
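The new weighting keeps the reported loss equal to the bare field loss, since (loss_adv - loss_adv.item()) is zero in value, while scaling the adversarial gradient by the running ratio of the two losses; the 1e-8 guards against division by zero. A minimal standalone check of that identity, with toy scalar losses standing in for the real ones:

```python
import torch

x = torch.tensor(1.0, requires_grad=True)
loss = (x - 3.0) ** 2  # stand-in field loss, value 4.0, grad -4.0
loss_adv = 2.0 * x     # stand-in adversarial loss, value 2.0, grad 2.0

# The diff's new weighting: the total's *value* stays 4.0 because the
# added term is zero-valued, but the adversarial *gradient* is scaled
# by loss/loss_adv so the two terms contribute comparably.
loss_fac = loss.item() / (loss_adv.item() + 1e-8)
total = loss + loss_fac * (loss_adv - loss_adv.item())
total.backward()

print(total.item())  # 4.0, the bare field-loss value
print(x.grad)        # ~0 here: -4.0 from loss, +loss_fac * 2.0 = ~+4.0
```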