From dfef34ba39890da4c3e857603768473a231ca39e Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 24 Jan 2022 15:08:29 -0700
Subject: [PATCH] Load ema to cpu memory if specified

---
 codes/trainer/ExtensibleTrainer.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 7e2d56f8..595e4921 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -45,6 +45,9 @@ class ExtensibleTrainer(BaseModel):
             self.env['mega_batch_factor'] = self.mega_batch_factor
             self.batch_factor = self.mega_batch_factor
             self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
+            # It is advantageous for large networks to do this to save an extra copy of the model weights.
+            # It does come at the cost of a round trip to CPU memory at every batch.
+            self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
 
@@ -147,6 +150,8 @@ class ExtensibleTrainer(BaseModel):
                         self.networks[k] = dnet
                         if self.is_train:
                             self.emas[k] = copy.deepcopy(v)
+                            if self.ema_on_cpu:
+                                self.emas[k] = self.emas[k].cpu()
                         found += 1
         assert found == len(self.netsG) + len(self.netsD)
 
@@ -316,6 +321,8 @@ class ExtensibleTrainer(BaseModel):
                     ema_params = self.emas[name].parameters()
                     net_params = net.parameters()
                     for ep, np in zip(ema_params, net_params):
+                        if self.ema_on_cpu:
+                            np = np.cpu()
                         ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
 
         [e.after_optimize(state) for e in self.experiments]
@@ -443,6 +450,8 @@ class ExtensibleTrainer(BaseModel):
                     else:
                         print("WARNING! Unable to find EMA network! Starting a new EMA from given model parameters.")
                         self.emas[name] = copy.deepcopy(net)
+                        if self.ema_on_cpu:
+                            self.emas[name] = self.emas[name].cpu()
                 if hasattr(net.module, 'network_loaded'):
                     net.module.network_loaded()
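
For reference, the patch keeps each EMA copy in CPU memory and pulls the live parameters back
to the CPU whenever the EMA is updated, trading a host/device round trip per step for the GPU
memory of a second set of weights. A minimal standalone sketch of that update pattern is below;
the toy model and the update_ema helper are illustrative only and are not part of
ExtensibleTrainer.

    import copy
    import torch
    import torch.nn as nn

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Linear(16, 16).to(device)  # stand-in for a registered network

    # Keep the EMA copy on the CPU so it does not occupy GPU memory.
    ema_model = copy.deepcopy(model).cpu()
    ema_rate = 0.999

    def update_ema(model, ema_model, ema_rate, ema_on_cpu=True):
        with torch.no_grad():
            for ep, p in zip(ema_model.parameters(), model.parameters()):
                if ema_on_cpu:
                    # The round trip mentioned in the patch: copy the live
                    # parameter to CPU before blending it into the EMA.
                    p = p.cpu()
                ep.detach().mul_(ema_rate).add_(p, alpha=1 - ema_rate)

    update_ema(model, ema_model, ema_rate)

With the patch applied, the behavior is presumably enabled by setting ema_on_cpu: true in the
train section of the experiment options (it defaults to False), alongside the existing ema_rate
setting.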