Allow EMA training to be disabled

James Betker 2022-02-12 20:00:23 -07:00
parent 3252972057
commit 15fd60aad3

@@ -54,6 +54,7 @@ class ExtensibleTrainer(BaseModel):
         self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         # It is advantageous for large networks to do this to save an extra copy of the model weights.
         # It does come at the cost of a round trip to CPU memory at every batch.
+        self.do_emas = opt_get(train_opt, ['ema_enabled'], True)
         self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
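
The new flag is read from the training options with a default of True, so existing configurations keep EMA enabled unless they opt out explicitly. A minimal sketch of the lookup, using a simplified stand-in for the project's opt_get helper and a hypothetical train_opt dict:

# Simplified stand-in for opt_get: walk nested keys and fall back to a default.
def opt_get(opt, keys, default=None):
    for key in keys:
        if not isinstance(opt, dict) or key not in opt:
            return default
        opt = opt[key]
    return opt

# Hypothetical training options: setting 'ema_enabled' to False disables EMA entirely.
train_opt = {'ema_rate': 0.999, 'ema_enabled': False}
do_emas = opt_get(train_opt, ['ema_enabled'], True)       # -> False
ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)    # -> False (key absent)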
@@ -156,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
                     if v == dnet.module:
                         net_dict[k] = dnet
                         self.networks[k] = dnet
-                        if self.is_train:
+                        if self.is_train and self.do_emas:
                             self.emas[k] = copy.deepcopy(v)
                             if self.ema_on_cpu:
                                 self.emas[k] = self.emas[k].cpu()
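
With EMAs enabled, each registered network gets a deep-copied shadow of its current weights, optionally kept in CPU memory to save a GPU-side copy. A standalone sketch of that initialization, using a hypothetical nn.Linear in place of the trainer's networks:

import copy
import torch.nn as nn

net = nn.Linear(16, 16)           # stand-in for a registered network
do_emas, ema_on_cpu = True, True  # values the trainer would read from the options

if do_emas:
    ema = copy.deepcopy(net)      # shadow copy, identical to the live weights at first
    if ema_on_cpu:
        ema = ema.cpu()           # saves GPU memory at the cost of per-step transfers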
@@ -357,18 +358,20 @@ class ExtensibleTrainer(BaseModel):
             else:
                 if k in self.networks.keys():  # This isn't always the case, for example for EMAs.
                     self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
-                    self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
+                    if self.do_emas:
+                        self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)

         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
             if hasattr(net, "custom_optimizer_step"):
                 net.custom_optimizer_step(it)
-            ema_params = self.emas[name].parameters()
-            net_params = net.parameters()
-            for ep, np in zip(ema_params, net_params):
-                if self.ema_on_cpu:
-                    np = np.cpu()
-                ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
+            if self.do_emas:
+                ema_params = self.emas[name].parameters()
+                net_params = net.parameters()
+                for ep, np in zip(ema_params, net_params):
+                    if self.ema_on_cpu:
+                        np = np.cpu()
+                    ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)

         [e.after_optimize(state) for e in self.experiments]
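
The update itself is the usual exponential moving average, ema = ema_rate * ema + (1 - ema_rate) * param, applied in place to the shadow parameters and now skipped entirely when do_emas is off. A minimal sketch of that loop outside the trainer, with a hypothetical net/ema pair (the no_grad guard is added here only to keep the sketch self-contained):

import copy
import torch
import torch.nn as nn

net = nn.Linear(8, 8)                    # stand-in for the live, trained network
ema = copy.deepcopy(net)                 # its EMA shadow
do_emas, ema_on_cpu, ema_rate = True, False, 0.999

if do_emas:
    with torch.no_grad():                # keep the EMA update out of the autograd graph
        for ep, p in zip(ema.parameters(), net.parameters()):
            src = p.cpu() if ema_on_cpu else p   # move the live weight to the EMA's device
            # In place on the shadow weight: ep <- ema_rate * ep + (1 - ema_rate) * src
            ep.detach().mul_(ema_rate).add_(src, alpha=1 - ema_rate)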
@@ -455,7 +458,7 @@ class ExtensibleTrainer(BaseModel):
                 logger.info('Loading model for [%s]' % (load_path,))
                 self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
                 load_path_ema = load_path.replace('.pth', '_ema.pth')
-                if self.is_train:
+                if self.is_train and self.do_emas:
                     ema_model = self.emas[name]
                     if os.path.exists(load_path_ema):
                         self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
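
EMA weights are looked up next to the base checkpoint under an '_ema' suffix, and only when training with EMAs enabled. A small sketch of the path derivation, using a hypothetical checkpoint path:

import os

# Hypothetical checkpoint; its EMA copy sits beside it with an '_ema' suffix.
load_path = '../experiments/example/models/5000_generator.pth'
load_path_ema = load_path.replace('.pth', '_ema.pth')
print(load_path_ema)   # ../experiments/example/models/5000_generator_ema.pth

# The EMA file is only consulted when training with EMAs enabled and the file exists.
is_train, do_emas = True, True
load_ema = is_train and do_emas and os.path.exists(load_path_ema)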
@@ -472,7 +475,8 @@ class ExtensibleTrainer(BaseModel):
             # Don't save non-trainable networks.
             if self.opt['networks'][name]['trainable']:
                 self.save_network(net, name, iter_step)
-                self.save_network(self.emas[name], f'{name}_ema', iter_step)
+                if self.do_emas:
+                    self.save_network(self.emas[name], f'{name}_ema', iter_step)

     def force_restore_swapout(self):
         # Legacy method. Do nothing.
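
Saving mirrors loading: when do_emas is set, the shadow network is written under the '<name>_ema' label, which is what produces the '_ema' checkpoint files the load path above looks for. A brief sketch of the naming, with hypothetical values:

name, iter_step = 'generator', 5000   # hypothetical network name and iteration
ema_label = f'{name}_ema'             # -> 'generator_ema'
# save_network(self.emas[name], ema_label, iter_step) would then write the '_ema'
# checkpoint alongside the regular one; with EMAs disabled this call is skipped.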