diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 186a84b6..90ed33c2 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -401,12 +401,22 @@ class ExtensibleTrainer(BaseModel):
             if hasattr(net.module, "after_step"):
                 net.module.after_step(it)
             if self.do_emas:
+                # When the EMA is on the CPU, only update every 10 steps to save processing time.
+                if self.ema_on_cpu and step % 10 != 0:
+                    continue
                 ema_params = self.emas[name].parameters()
                 net_params = net.parameters()
                 for ep, np in zip(ema_params, net_params):
+                    ema_rate = self.ema_rate
+                    new_rate = 1 - ema_rate
                     if self.ema_on_cpu:
                         np = np.cpu()
-                    ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
+                        ema_rate = ema_rate ** 10  # Because it only happens every 10 steps.
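+                        # ema_rate + new_rate now sums to less than 1; split the shortfall evenly so they sum to 1 again.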
+                        mid = (1 - (ema_rate + new_rate)) / 2
+                        ema_rate += mid
+                        new_rate += mid
+                    ep.detach().mul_(ema_rate).add_(np, alpha=new_rate)
         [e.after_optimize(state) for e in self.experiments]
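
A minimal sketch (plain Python, not part of the patch) of the rate math in the
CPU branch above, assuming self.ema_rate = 0.999 as a typical decay value:

    r = 0.999                       # assumed value of self.ema_rate
    keep, new = r ** 10, 1 - r      # compounded keep-rate, per-step blend-rate
    mid = (1 - (keep + new)) / 2    # shortfall from 1, split evenly
    keep, new = keep + mid, new + mid
    print(keep, new)                # ~0.99452, ~0.00548; sums to 1
    print(r ** 10, 1 - r ** 10)     # ~0.99004, ~0.00996 (exact 10-step compounding)

The renormalized weights stay close to exact 10-step compounding
(r ** 10, 1 - r ** 10) while still summing to 1.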