From fc09cff4b3f1cef976172847cb5c00e854c8b6c4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 24 Jan 2022 18:09:29 -0700 Subject: [PATCH] angry --- codes/train.py | 2 +- codes/trainer/base_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codes/train.py b/codes/train.py index 812b4b56..d6142657 100644 --- a/codes/train.py +++ b/codes/train.py @@ -215,7 +215,7 @@ class Trainer: #### save models and training states if self.current_step % opt['logger']['save_checkpoint_freq'] == 0: - self.model.consolidate_state() + self.model.consolidate_state(state) if self.rank <= 0: self.logger.info('Saving models and training states.') self.model.save(self.current_step) diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 4ea7edae..56f33603 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -129,7 +129,7 @@ class BaseModel(): network.load_state_dict(load_net_clean, strict=strict) - def consolidate_state(self): + def consolidate_state(self, state): for o in self.optimizers: if isinstance(o, ZeroRedundancyOptimizer): state['optimizers'].append(o.consolidate_state_dict(to=0))