From efabcf500838d6bc7e28378bcd8355bf4e7d0885 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 12 Jun 2022 18:34:58 -0600 Subject: [PATCH] When ema is on CPU, only update every 10 steps. --- codes/trainer/ExtensibleTrainer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 186a84b6..90ed33c2 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -401,12 +401,21 @@ class ExtensibleTrainer(BaseModel): if hasattr(net.module, "after_step"): net.module.after_step(it) if self.do_emas: + # When the EMA is on the CPU, only update every 10 steps to save processing time. + if self.ema_on_cpu and step % 5 != 0: + continue ema_params = self.emas[name].parameters() net_params = net.parameters() for ep, np in zip(ema_params, net_params): + ema_rate = self.ema_rate + new_rate = 1 - ema_rate if self.ema_on_cpu: np = np.cpu() - ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate) + ema_rate = ema_rate ** 10 # Because it only happens every 10 steps. + mid = (1 - (ema_rate+new_rate))/2 + ema_rate += mid + new_rate += mid + ep.detach().mul_(ema_rate).add_(np, alpha=1 - ema_rate) [e.after_optimize(state) for e in self.experiments]