From 798ed7730a940d6c60b70bfd64a7673624ef5f17 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 24 Jan 2022 18:12:08 -0700 Subject: [PATCH] i like wasting time --- codes/train.py | 2 +- codes/trainer/base_model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codes/train.py b/codes/train.py index d6142657..812b4b56 100644 --- a/codes/train.py +++ b/codes/train.py @@ -215,7 +215,7 @@ class Trainer: #### save models and training states if self.current_step % opt['logger']['save_checkpoint_freq'] == 0: - self.model.consolidate_state(state) + self.model.consolidate_state() if self.rank <= 0: self.logger.info('Saving models and training states.') self.model.save(self.current_step) diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py index 56f33603..ba635f90 100644 --- a/codes/trainer/base_model.py +++ b/codes/trainer/base_model.py @@ -129,10 +129,10 @@ class BaseModel(): network.load_state_dict(load_net_clean, strict=strict) - def consolidate_state(self, state): + def consolidate_state(self): for o in self.optimizers: if isinstance(o, ZeroRedundancyOptimizer): - state['optimizers'].append(o.consolidate_state_dict(to=0)) + o.consolidate_state_dict(to=0) def save_training_state(self, state):