This commit is contained in:
James Betker 2022-01-24 18:09:29 -07:00
parent cc0d9f7216
commit fc09cff4b3
2 changed files with 2 additions and 2 deletions

View File

@@ -215,7 +215,7 @@ class Trainer:
#### save models and training states
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
-self.model.consolidate_state()
+self.model.consolidate_state(state)
if self.rank <= 0:
self.logger.info('Saving models and training states.')
self.model.save(self.current_step)

View File

@@ -129,7 +129,7 @@ class BaseModel():
network.load_state_dict(load_net_clean, strict=strict)
-def consolidate_state(self):
+def consolidate_state(self, state):
for o in self.optimizers:
if isinstance(o, ZeroRedundancyOptimizer):
state['optimizers'].append(o.consolidate_state_dict(to=0))