diff --git a/codes/train.py b/codes/train.py
index 2846fe09..812b4b56 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -215,6 +215,8 @@ class Trainer:
 
                 #### save models and training states
                 if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                    # Collective call: all ranks must participate, hence before the rank-0 guard.
+                    self.model.consolidate_state()
                     if self.rank <= 0:
                         self.logger.info('Saving models and training states.')
                         self.model.save(self.current_step)
diff --git a/codes/trainer/base_model.py b/codes/trainer/base_model.py
index 329140ad..56f33603 100644
--- a/codes/trainer/base_model.py
+++ b/codes/trainer/base_model.py
@@ -2,6 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
+from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 import utils.util
 
@@ -127,6 +128,13 @@ class BaseModel():
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
 
+    def consolidate_state(self):
+        # ZeroRedundancyOptimizer shards optimizer state across ranks; gather
+        # the full state onto rank 0 so save_training_state() can serialize it.
+        for o in self.optimizers:
+            if isinstance(o, ZeroRedundancyOptimizer):
+                o.consolidate_state_dict(to=0)
+
     def save_training_state(self, state):
         """Save training state during training, which will be used for resuming"""
         state.update({'schedulers': [], 'optimizers': []})

Note: ZeroRedundancyOptimizer.consolidate_state_dict(to=0) returns None, so the
earlier draft's state['optimizers'].append(...) appended None, and its unused
state parameter did not match the no-argument call site in train.py. The method
now only gathers the shards; save_training_state() reads the consolidated state
via each optimizer's state_dict() afterwards.
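For context, a minimal self-contained sketch of the ZeRO checkpointing pattern
this diff relies on. It is an illustration, not the repo's code: the toy model,
learning rate, and output filename are placeholder assumptions, and it assumes
a process group launched via torchrun. The key point mirrors the call placement
above: consolidate_state_dict(to=0) is a collective that every rank must enter,
after which only rank 0 holds (and can save) the full optimizer state.

```python
# Hypothetical standalone sketch of saving a ZeroRedundancyOptimizer checkpoint.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    dist.init_process_group('gloo')  # use 'nccl' on GPU nodes
    rank = dist.get_rank()

    model = DDP(nn.Linear(8, 8))  # placeholder model
    # ZeRO shards the optimizer state (e.g. Adam moments) across ranks.
    opt = ZeroRedundancyOptimizer(model.parameters(),
                                  optimizer_class=torch.optim.Adam,
                                  lr=1e-3)

    # One dummy training step so there is optimizer state to checkpoint.
    model(torch.randn(4, 8)).sum().backward()
    opt.step()

    # Collective: every rank must call this so its shard can be gathered;
    # after it returns, rank 0 holds the full optimizer state.
    opt.consolidate_state_dict(to=0)
    if rank == 0:
        torch.save(opt.state_dict(), 'optimizer.pth')  # placeholder path

    dist.destroy_process_group()


if __name__ == '__main__':
    main()
```

Run with e.g. `torchrun --nproc_per_node=2 zero_ckpt_sketch.py`. Calling
state_dict() on a non-target rank, or skipping the consolidation on any rank,
raises or deadlocks, which is why the diff places consolidate_state() outside
the `if self.rank <= 0:` block.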