forked from mrq/DL-Art-School
When ema is on CPU, only update every 10 steps.
This commit is contained in:
parent
fc3a7ed5e3
commit
efabcf5008
|
@ -401,12 +401,21 @@ class ExtensibleTrainer(BaseModel):
|
||||||
if hasattr(net.module, "after_step"):
|
if hasattr(net.module, "after_step"):
|
||||||
net.module.after_step(it)
|
net.module.after_step(it)
|
||||||
if self.do_emas:
|
if self.do_emas:
|
||||||
|
# When the EMA is on the CPU, only update every 10 steps to save processing time.
|
||||||
|
if self.ema_on_cpu and step % 5 != 0:
|
||||||
|
continue
|
||||||
ema_params = self.emas[name].parameters()
|
ema_params = self.emas[name].parameters()
|
||||||
net_params = net.parameters()
|
net_params = net.parameters()
|
||||||
for ep, np in zip(ema_params, net_params):
|
for ep, np in zip(ema_params, net_params):
|
||||||
|
ema_rate = self.ema_rate
|
||||||
|
new_rate = 1 - ema_rate
|
||||||
if self.ema_on_cpu:
|
if self.ema_on_cpu:
|
||||||
np = np.cpu()
|
np = np.cpu()
|
||||||
ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
|
ema_rate = ema_rate ** 10 # Because it only happens every 10 steps.
|
||||||
|
mid = (1 - (ema_rate+new_rate))/2
|
||||||
|
ema_rate += mid
|
||||||
|
new_rate += mid
|
||||||
|
ep.detach().mul_(ema_rate).add_(np, alpha=1 - ema_rate)
|
||||||
[e.after_optimize(state) for e in self.experiments]
|
[e.after_optimize(state) for e in self.experiments]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user