Fix adversary model optim scheduler
commit e7d2435a96
parent 1b1e0e82fa
@@ -12,7 +12,7 @@ class FieldDataset(Dataset):
     `in_patterns` is a list of glob patterns for the input fields.
     For example, `in_patterns=['/train/field1_*.npy', '/train/field2_*.npy']`.
     Likewise `tgt_patterns` is for target fields.
-    Input and target samples of all fields are matched by sorting the globbed files.
+    Input and target samples are matched by sorting the globbed files.
 
     `norms` can be a list of callables to normalize each field.
 
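Note: the docstring's matching rule means sample i of every field is simply the i-th file after sorting each glob. A minimal sketch of that idea (the target pattern below is a made-up path, and this is not the class's actual implementation):

    from glob import glob

    in_patterns = ['/train/field1_*.npy', '/train/field2_*.npy']
    tgt_patterns = ['/train/target1_*.npy']  # hypothetical target pattern

    # one sorted file list per field; sample i uses the i-th file of every list
    in_files = [sorted(glob(p)) for p in in_patterns]
    tgt_files = [sorted(glob(p)) for p in tgt_patterns]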
@@ -32,12 +32,13 @@ def test(args):
 
     device = torch.device('cpu')
     state = torch.load(args.load_state, map_location=device)
-    from collections import OrderedDict
-    model_state = OrderedDict()
-    for k, v in state['model'].items():
-        model_k = k.replace('module.', '', 1) # FIXME
-        model_state[model_k] = v
-    model.load_state_dict(model_state)
+    # from collections import OrderedDict
+    # model_state = OrderedDict()
+    # for k, v in state['model'].items():
+    #     model_k = k.replace('module.', '', 1) # FIXME
+    #     model_state[model_k] = v
+    # model.load_state_dict(model_state)
+    model.load_state_dict(state['model'])
     print('model state at epoch {} loaded from {}'.format(
         state['epoch'], args.load_state))
     del state
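Note: the block commented out above stripped the 'module.' prefix that nn.DataParallel / DistributedDataParallel adds to parameter keys; loading state['model'] directly assumes the checkpoint was saved without that wrapper. A minimal sketch of the prefix-stripping fallback (strip_ddp_prefix is a hypothetical helper, not part of this commit):

    def strip_ddp_prefix(model_state):
        # drop a single leading 'module.' from each key, if present
        return {k[len('module.'):] if k.startswith('module.') else k: v
                for k, v in model_state.items()}

    # hypothetical usage:
    # model.load_state_dict(strip_ddp_prefix(state['model']))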
@@ -130,8 +130,7 @@ def gpu_worker(local_rank, args):
         betas=(0.5, 0.999),
         weight_decay=args.adv_weight_decay,
     )
-    adv_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adv_optimizer,
-            factor=0.5, patience=3, verbose=True)
+    adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1)
 
     if args.load_state:
         state = torch.load(args.load_state, map_location=args.device)
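Note: torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1) decays every parameter-group learning rate by a factor of 0.1 once every 30 step() calls; the positional 30 is step_size. A minimal standalone sketch of that decay (the parameter and learning rate below are illustrative only):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    adv_optimizer = torch.optim.Adam(params, lr=1e-3, betas=(0.5, 0.999))
    adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1)

    for epoch in range(90):
        # ... one training epoch ...
        adv_scheduler.step()
    # lr is 1e-3 for epochs 0-29, 1e-4 for 30-59, 1e-5 for 60-89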
@@ -185,6 +184,8 @@ def gpu_worker(local_rank, args):
             epoch_loss = val_loss
 
         scheduler.step(epoch_loss[0])
+        if args.adv:
+            adv_scheduler.step()
 
         if args.rank == 0:
             print(end='', flush=True)
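Note: the two schedulers are advanced differently, which is why the new adversary call takes no argument: ReduceLROnPlateau.step() expects the monitored metric, while StepLR.step() just counts calls. A runnable toy sketch of the per-epoch pattern (the parameters, optimizers, and loss value are placeholders):

    import torch

    p = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(p, lr=1e-3)
    adv_optimizer = torch.optim.Adam(p, lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    adv_scheduler = torch.optim.lr_scheduler.StepLR(adv_optimizer, 30, gamma=0.1)

    for epoch in range(5):
        epoch_loss = [1.0]              # stand-in for this epoch's validation losses
        scheduler.step(epoch_loss[0])   # plateau scheduler needs the metric
        adv_scheduler.step()            # StepLR takes no argument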
@@ -274,7 +275,6 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             epoch_loss[4] += adv_loss_real.item()
 
             adv_loss = 0.5 * (adv_loss_fake + adv_loss_real)
-            adv_loss = 0.001 * adv_loss + 0.999 * adv_loss.item()
             epoch_loss[2] += adv_loss.item()
 
             adv_optimizer.zero_grad()
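Note: the removed line relied on .item() returning a detached Python float, so 0.001 * adv_loss + 0.999 * adv_loss.item() keeps the same scalar value as adv_loss while only 0.001 of its gradient reaches the discriminator; removing it restores the full-strength discriminator update. A toy check of that value/gradient split (the tensor below is illustrative):

    import torch

    x = torch.tensor(2.0, requires_grad=True)
    adv_loss = x * x                              # value 4.0, d/dx = 4.0

    scaled = 0.001 * adv_loss + 0.999 * adv_loss.item()
    print(scaled.item())                          # 4.0 -- value unchanged
    scaled.backward()
    print(x.grad)                                 # 0.004 -- gradient scaled by 0.001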
@@ -287,9 +287,8 @@ def train(epoch, loader, model, criterion, optimizer, scheduler,
             loss_adv = adv_criterion(eval_out, real) # FIXME try min
             epoch_loss[1] += loss_adv.item()
 
-            # loss_fac = loss.item() / (loss.item() + loss_adv.item())
-            # loss = 0.5 * (loss * (1 + loss_fac) + loss_adv * loss_fac) # FIXME does this work?
-            loss += 0.001 * (loss_adv - loss_adv.item())
+            loss_fac = loss.item() / (loss_adv.item() + 1e-8)
+            loss += loss_fac * (loss_adv - loss_adv.item()) # FIXME does this work?
 
         optimizer.zero_grad()
         loss.backward()
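Note: the new generator-side term uses the same detachment trick: (loss_adv - loss_adv.item()) is numerically zero, so the logged loss value is unchanged, but its gradient is that of loss_adv scaled by loss_fac = loss / loss_adv, which rescales the adversarial gradient to the magnitude of the main loss. A toy check (the linear losses below are illustrative only):

    import torch

    x = torch.tensor(1.0, requires_grad=True)
    loss = 10.0 * x                                       # main loss, value 10
    loss_adv = 0.5 * x                                    # adversarial loss, value 0.5

    loss_fac = loss.item() / (loss_adv.item() + 1e-8)     # ~20
    total = loss + loss_fac * (loss_adv - loss_adv.item())

    print(total.item())   # 10.0 -- the adversarial term adds nothing to the value
    total.backward()
    print(x.grad)         # ~20.0 = 10 (main) + loss_fac * 0.5 (rescaled adversarial gradient)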