diff --git a/map2map/models/instance_noise.py b/map2map/models/instance_noise.py index 8fd1dbc..02dd77d 100644 --- a/map2map/models/instance_noise.py +++ b/map2map/models/instance_noise.py @@ -6,13 +6,12 @@ class InstanceNoise: """ def __init__(self, init_std): self.init_std = init_std - self.adv_loss_cum = 0 + self.anneal = 1 self.ln2 = log(2) self.batches = 1e5 def std(self, adv_loss): - self.adv_loss_cum += adv_loss - anneal = 1 - self.adv_loss_cum / self.ln2 / self.batches - anneal = anneal if anneal > 0 else 0 - std = anneal * self.init_std + self.anneal -= adv_loss / self.ln2 / self.batches + std = self.anneal * self.init_std + std = std if std > 0 else 0 return std