forked from mrq/DL-Art-School
Allow EMA training to be disabled
parent 3252972057
commit 15fd60aad3
@@ -54,6 +54,7 @@ class ExtensibleTrainer(BaseModel):
         self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         # It is advantageous for large networks to do this to save an extra copy of the model weights.
         # It does come at the cost of a round trip to CPU memory at every batch.
+        self.do_emas = opt_get(train_opt, ['ema_enabled'], True)
         self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
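The new flag is resolved through opt_get, so configs that never mention it keep EMA weights enabled; only an explicit ema_enabled: false in the training options turns them off. A minimal sketch of how the lookup behaves, assuming opt_get simply walks a key path and falls back to the default when a key is absent (the option values and the helper below are illustrative, not taken from this commit):

    # Hypothetical 'train' options dict as ExtensibleTrainer would receive it.
    train_opt = {'ema_rate': 0.999, 'ema_enabled': False, 'ema_on_cpu': False}

    def opt_get(opt, keys, default=None):
        # Assumed behaviour: follow the key path, return the default if any key is missing.
        for k in keys:
            if not isinstance(opt, dict) or k not in opt:
                return default
            opt = opt[k]
        return opt

    do_emas = opt_get(train_opt, ['ema_enabled'], True)   # False here; True for untouched configs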
@@ -156,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
                     if v == dnet.module:
                         net_dict[k] = dnet
                         self.networks[k] = dnet
-                        if self.is_train:
+                        if self.is_train and self.do_emas:
                             self.emas[k] = copy.deepcopy(v)
                             if self.ema_on_cpu:
                                 self.emas[k] = self.emas[k].cpu()
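For reference, the block guarded above amounts to keeping one independent shadow copy per wrapped network, optionally parked in CPU memory to trade GPU memory for a per-step transfer. A standalone sketch under those assumptions (the toy module and flag values are illustrative):

    import copy
    import torch.nn as nn

    net = nn.Linear(8, 8)                    # stand-in for a wrapped training network
    do_emas, ema_on_cpu = True, True         # mirrors the two config flags

    emas = {}
    if do_emas:
        emas['net'] = copy.deepcopy(net)     # starts with the same weights as the live network
        if ema_on_cpu:
            emas['net'] = emas['net'].cpu()  # saves GPU memory, costs a CPU round trip each step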
@@ -357,18 +358,20 @@ class ExtensibleTrainer(BaseModel):
                 else:
                     if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
                         self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
-                        self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
+                        if self.do_emas:
+                            self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
             if hasattr(net, "custom_optimizer_step"):
                 net.custom_optimizer_step(it)
-            ema_params = self.emas[name].parameters()
-            net_params = net.parameters()
-            for ep, np in zip(ema_params, net_params):
-                if self.ema_on_cpu:
-                    np = np.cpu()
-                ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
+            if self.do_emas:
+                ema_params = self.emas[name].parameters()
+                net_params = net.parameters()
+                for ep, np in zip(ema_params, net_params):
+                    if self.ema_on_cpu:
+                        np = np.cpu()
+                    ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
         [e.after_optimize(state) for e in self.experiments]
 
 
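The update loop is a standard exponential moving average, ema = ema_rate * ema + (1 - ema_rate) * param, applied in-place to the shadow parameters on every optimizer step; when ema_on_cpu is set, each live parameter is first copied to CPU, which is the round trip mentioned in the comment near the top of the diff. A minimal self-contained sketch of the update rule (tensor values and the 0.999 rate are illustrative):

    import torch

    ema_rate = 0.999
    ema_param = torch.zeros(4)    # shadow parameter (possibly resident on CPU)
    live_param = torch.ones(4)    # current training parameter

    # In-place EMA update, matching ep.detach().mul_(rate).add_(np, alpha=1 - rate)
    ema_param.detach().mul_(ema_rate).add_(live_param, alpha=1 - ema_rate)
    print(ema_param)              # tensor([0.0010, 0.0010, 0.0010, 0.0010])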
@@ -455,7 +458,7 @@ class ExtensibleTrainer(BaseModel):
                 logger.info('Loading model for [%s]' % (load_path,))
                 self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
                 load_path_ema = load_path.replace('.pth', '_ema.pth')
-                if self.is_train:
+                if self.is_train and self.do_emas:
                     ema_model = self.emas[name]
                     if os.path.exists(load_path_ema):
                         self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
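EMA weights are expected next to the base checkpoint under a _ema suffix, and the os.path.exists guard means a missing sibling file is simply skipped rather than raising. A sketch of the path derivation (the checkpoint name is hypothetical):

    import os

    load_path = '../experiments/my_run/models/5000_generator.pth'   # hypothetical path
    load_path_ema = load_path.replace('.pth', '_ema.pth')
    # -> '../experiments/my_run/models/5000_generator_ema.pth'
    if os.path.exists(load_path_ema):
        pass  # load the EMA weights; otherwise the EMA copy keeps its initial deepcopy state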
@@ -472,7 +475,8 @@ class ExtensibleTrainer(BaseModel):
             # Don't save non-trainable networks.
             if self.opt['networks'][name]['trainable']:
                 self.save_network(net, name, iter_step)
-                self.save_network(self.emas[name], f'{name}_ema', iter_step)
+                if self.do_emas:
+                    self.save_network(self.emas[name], f'{name}_ema', iter_step)
 
     def force_restore_swapout(self):
         # Legacy method. Do nothing.
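With ema_enabled set to false, save() no longer writes the {name}_ema checkpoints, so the on-disk footprint is just the base networks; and because the loader above checks os.path.exists before touching the _ema.pth sibling, checkpoints produced without EMAs still load cleanly in runs that re-enable them.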