forked from mrq/DL-Art-School
angry
This commit is contained in:
parent
cc0d9f7216
commit
fc09cff4b3
|
@ -215,7 +215,7 @@ class Trainer:
|
|||
|
||||
#### save models and training states
|
||||
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||
self.model.consolidate_state()
|
||||
self.model.consolidate_state(state)
|
||||
if self.rank <= 0:
|
||||
self.logger.info('Saving models and training states.')
|
||||
self.model.save(self.current_step)
|
||||
|
|
|
@ -129,7 +129,7 @@ class BaseModel():
|
|||
network.load_state_dict(load_net_clean, strict=strict)
|
||||
|
||||
|
||||
def consolidate_state(self):
|
||||
def consolidate_state(self, state):
|
||||
for o in self.optimizers:
|
||||
if isinstance(o, ZeroRedundancyOptimizer):
|
||||
state['optimizers'].append(o.consolidate_state_dict(to=0))
|
||||
|
|
Loading…
Reference in New Issue
Block a user