forked from mrq/DL-Art-School
Load ema to cpu memory if specified
parent 49edffb6ad
commit dfef34ba39
@@ -45,6 +45,9 @@ class ExtensibleTrainer(BaseModel):
             self.env['mega_batch_factor'] = self.mega_batch_factor
             self.batch_factor = self.mega_batch_factor
             self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
+            # It is advantageous for large networks to do this to save an extra copy of the model weights.
+            # It does come at the cost of a round trip to CPU memory at every batch.
+            self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)

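As a usage note, the new flag is read with opt_get and defaults to False, so enabling it is a one-line addition to the training options. A minimal sketch of how the option is consumed (the opt_get stand-in below is simplified and the dict shape is an assumption about the parsed config; only the key names ema_rate and ema_on_cpu come from the diff):

    # Simplified stand-in for the repo's opt_get helper: nested lookup with a default.
    def opt_get(opt, keys, default=None):
        for k in keys:
            if opt is None or k not in opt:
                return default
            opt = opt[k]
        return opt

    train_opt = {'ema_rate': 0.999, 'ema_on_cpu': True}     # illustrative config fragment
    ema_rate = opt_get(train_opt, ['ema_rate'], .999)       # -> 0.999
    ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)  # -> True; False when the key is absent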
@@ -147,6 +150,8 @@ class ExtensibleTrainer(BaseModel):
                     self.networks[k] = dnet
                     if self.is_train:
                         self.emas[k] = copy.deepcopy(v)
+                        if self.ema_on_cpu:
+                            self.emas[k] = self.emas[k].cpu()
                     found += 1
         assert found == len(self.netsG) + len(self.netsD)

@@ -316,6 +321,8 @@ class ExtensibleTrainer(BaseModel):
                 ema_params = self.emas[name].parameters()
                 net_params = net.parameters()
                 for ep, np in zip(ema_params, net_params):
+                    if self.ema_on_cpu:
+                        np = np.cpu()
                     ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)

         [e.after_optimize(state) for e in self.experiments]
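The update above is the round trip the new comment warns about: the EMA copy stays resident in CPU memory, and every optimizer step pulls the live parameters across to fold them in. A minimal, self-contained sketch of the same pattern (PyTorch assumed; the tiny Linear module is illustrative and is left on the default device so the sketch runs without a GPU, whereas the trainer's live network would be on the GPU):

    import copy
    import torch

    ema_rate = 0.999
    net = torch.nn.Linear(8, 8)        # stands in for the live network
    ema = copy.deepcopy(net).cpu()     # EMA weights live in CPU memory, saving a device-side copy

    with torch.no_grad():
        for ep, p in zip(ema.parameters(), net.parameters()):
            p_cpu = p.detach().cpu()   # per-step device -> CPU transfer
            ep.mul_(ema_rate).add_(p_cpu, alpha=1 - ema_rate)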
@@ -443,6 +450,8 @@ class ExtensibleTrainer(BaseModel):
             else:
                 print("WARNING! Unable to find EMA network! Starting a new EMA from given model parameters.")
                 self.emas[name] = copy.deepcopy(net)
+                if self.ema_on_cpu:
+                    self.emas[name] = self.emas[name].cpu()
         if hasattr(net.module, 'network_loaded'):
             net.module.network_loaded()