diff --git a/codes/train.py b/codes/train.py index 812b4b56..d6142657 100644 --- a/codes/train.py +++ b/codes/train.py @@ -215,7 +215,7 @@ class Trainer: #### save models and training states if self.current_step % opt['logger']['save_checkpoint_freq'] == 0: - self.model.consolidate_state() + self.model.consolidate_state(state) if self.rank <= 0: self.logger.info('Saving models and training states.') self.model.save(self.current_step) diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 4ea7edae..56f33603 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -129,7 +129,7 @@ class BaseModel(): network.load_state_dict(load_net_clean, strict=strict) - def consolidate_state(self): + def consolidate_state(self, state): for o in self.optimizers: if isinstance(o, ZeroRedundancyOptimizer): state['optimizers'].append(o.consolidate_state_dict(to=0))