forked from mrq/DL-Art-School
angry
This commit is contained in:
parent
cc0d9f7216
commit
fc09cff4b3
|
@ -215,7 +215,7 @@ class Trainer:
|
||||||
|
|
||||||
#### save models and training states
|
#### save models and training states
|
||||||
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||||
self.model.consolidate_state()
|
self.model.consolidate_state(state)
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
self.logger.info('Saving models and training states.')
|
self.logger.info('Saving models and training states.')
|
||||||
self.model.save(self.current_step)
|
self.model.save(self.current_step)
|
||||||
|
|
|
@ -129,7 +129,7 @@ class BaseModel():
|
||||||
network.load_state_dict(load_net_clean, strict=strict)
|
network.load_state_dict(load_net_clean, strict=strict)
|
||||||
|
|
||||||
|
|
||||||
def consolidate_state(self):
|
def consolidate_state(self, state):
|
||||||
for o in self.optimizers:
|
for o in self.optimizers:
|
||||||
if isinstance(o, ZeroRedundancyOptimizer):
|
if isinstance(o, ZeroRedundancyOptimizer):
|
||||||
state['optimizers'].append(o.consolidate_state_dict(to=0))
|
state['optimizers'].append(o.consolidate_state_dict(to=0))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user