Allow EMA training to be disabled

This commit is contained in:
James Betker 2022-02-12 20:00:23 -07:00
parent 3252972057
commit 15fd60aad3

View File

@ -54,6 +54,7 @@ class ExtensibleTrainer(BaseModel):
self.ema_rate = opt_get(train_opt, ['ema_rate'], .999) self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
# It is advantageous for large networks to do this to save an extra copy of the model weights. # It is advantageous for large networks to do this to save an extra copy of the model weights.
# It does come at the cost of a round trip to CPU memory at every batch. # It does come at the cost of a round trip to CPU memory at every batch.
self.do_emas = opt_get(train_opt, ['ema_enabled'], True)
self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False) self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
self.checkpointing_cache = opt['checkpointing_enabled'] self.checkpointing_cache = opt['checkpointing_enabled']
self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None) self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
@ -156,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
if v == dnet.module: if v == dnet.module:
net_dict[k] = dnet net_dict[k] = dnet
self.networks[k] = dnet self.networks[k] = dnet
if self.is_train: if self.is_train and self.do_emas:
self.emas[k] = copy.deepcopy(v) self.emas[k] = copy.deepcopy(v)
if self.ema_on_cpu: if self.ema_on_cpu:
self.emas[k] = self.emas[k].cpu() self.emas[k] = self.emas[k].cpu()
@ -357,18 +358,20 @@ class ExtensibleTrainer(BaseModel):
else: else:
if k in self.networks.keys(): # This isn't always the case, for example for EMAs. if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
self.load_network(ps[-self.auto_recover], self.networks[k], strict=True) self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True) if self.do_emas:
self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
# Call into custom step hooks as well as update EMA params. # Call into custom step hooks as well as update EMA params.
for name, net in self.networks.items(): for name, net in self.networks.items():
if hasattr(net, "custom_optimizer_step"): if hasattr(net, "custom_optimizer_step"):
net.custom_optimizer_step(it) net.custom_optimizer_step(it)
ema_params = self.emas[name].parameters() if self.do_emas:
net_params = net.parameters() ema_params = self.emas[name].parameters()
for ep, np in zip(ema_params, net_params): net_params = net.parameters()
if self.ema_on_cpu: for ep, np in zip(ema_params, net_params):
np = np.cpu() if self.ema_on_cpu:
ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate) np = np.cpu()
ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
[e.after_optimize(state) for e in self.experiments] [e.after_optimize(state) for e in self.experiments]
@ -455,7 +458,7 @@ class ExtensibleTrainer(BaseModel):
logger.info('Loading model for [%s]' % (load_path,)) logger.info('Loading model for [%s]' % (load_path,))
self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}'])) self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
load_path_ema = load_path.replace('.pth', '_ema.pth') load_path_ema = load_path.replace('.pth', '_ema.pth')
if self.is_train: if self.is_train and self.do_emas:
ema_model = self.emas[name] ema_model = self.emas[name]
if os.path.exists(load_path_ema): if os.path.exists(load_path_ema):
self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}'])) self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
@ -472,7 +475,8 @@ class ExtensibleTrainer(BaseModel):
# Don't save non-trainable networks. # Don't save non-trainable networks.
if self.opt['networks'][name]['trainable']: if self.opt['networks'][name]['trainable']:
self.save_network(net, name, iter_step) self.save_network(net, name, iter_step)
self.save_network(self.emas[name], f'{name}_ema', iter_step) if self.do_emas:
self.save_network(self.emas[name], f'{name}_ema', iter_step)
def force_restore_swapout(self): def force_restore_swapout(self):
# Legacy method. Do nothing. # Legacy method. Do nothing.