diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 2e4f11d8..5886a711 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -54,6 +54,7 @@ class ExtensibleTrainer(BaseModel):
         self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
         # It is advantageous for large networks to do this to save an extra copy of the model weights.
         # It does come at the cost of a round trip to CPU memory at every batch.
+        self.do_emas = opt_get(train_opt, ['ema_enabled'], True)
         self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
@@ -156,7 +157,7 @@ class ExtensibleTrainer(BaseModel):
                     if v == dnet.module:
                         net_dict[k] = dnet
                         self.networks[k] = dnet
-                        if self.is_train:
+                        if self.is_train and self.do_emas:
                             self.emas[k] = copy.deepcopy(v)
                             if self.ema_on_cpu:
                                 self.emas[k] = self.emas[k].cpu()
@@ -357,18 +358,20 @@ class ExtensibleTrainer(BaseModel):
                 else:
                     if k in self.networks.keys():  # This isn't always the case, for example for EMAs.
                         self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
-                        self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
+                        if self.do_emas:
+                            self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
 
         # Call into custom step hooks as well as update EMA params.
         for name, net in self.networks.items():
             if hasattr(net, "custom_optimizer_step"):
                 net.custom_optimizer_step(it)
-            ema_params = self.emas[name].parameters()
-            net_params = net.parameters()
-            for ep, np in zip(ema_params, net_params):
-                if self.ema_on_cpu:
-                    np = np.cpu()
-                ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
+            if self.do_emas:
+                ema_params = self.emas[name].parameters()
+                net_params = net.parameters()
+                for ep, np in zip(ema_params, net_params):
+                    if self.ema_on_cpu:
+                        np = np.cpu()
+                    ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
 
         [e.after_optimize(state) for e in self.experiments]
 
@@ -455,7 +458,7 @@ class ExtensibleTrainer(BaseModel):
                 logger.info('Loading model for [%s]' % (load_path,))
                 self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
                 load_path_ema = load_path.replace('.pth', '_ema.pth')
-                if self.is_train:
+                if self.is_train and self.do_emas:
                     ema_model = self.emas[name]
                     if os.path.exists(load_path_ema):
                         self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
@@ -472,7 +475,8 @@ class ExtensibleTrainer(BaseModel):
             # Don't save non-trainable networks.
             if self.opt['networks'][name]['trainable']:
                 self.save_network(net, name, iter_step)
-                self.save_network(self.emas[name], f'{name}_ema', iter_step)
+                if self.do_emas:
+                    self.save_network(self.emas[name], f'{name}_ema', iter_step)
 
     def force_restore_swapout(self):
         # Legacy method. Do nothing.
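The patch reads a new `ema_enabled` option (defaulting to `True`, so existing configs keep their current behavior) into `self.do_emas` and uses it to gate every EMA code path: creating the shadow copies, the per-step parameter blend, NaN-recovery reloads, loading `_ema.pth` checkpoints, and saving them. Below is a minimal standalone sketch of the bookkeeping the guarded code performs; the helper names are illustrative, not part of the repository.

```python
import copy
import torch

EMA_RATE = 0.999  # same default the trainer reads via opt_get(..., ['ema_rate'], .999)

def make_ema(model, ema_enabled=True, ema_on_cpu=False):
    """Create the shadow copy kept per network (skipped entirely when disabled)."""
    if not ema_enabled:
        return None
    ema = copy.deepcopy(model)
    return ema.cpu() if ema_on_cpu else ema

@torch.no_grad()
def update_ema(ema, model, ema_enabled=True, ema_on_cpu=False, rate=EMA_RATE):
    """Blend current weights into the EMA copy: ep = rate * ep + (1 - rate) * p."""
    if not ema_enabled or ema is None:
        return
    for ep, p in zip(ema.parameters(), model.parameters()):
        if ema_on_cpu:
            p = p.cpu()
        ep.mul_(rate).add_(p, alpha=1 - rate)
```

Setting `ema_enabled: false` in the train options therefore skips all of this and avoids the duplicate weight copy altogether, rather than only parking it in CPU memory as `ema_on_cpu` does.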