consolidate state

This commit is contained in:
James Betker 2022-01-24 17:59:31 -07:00
parent dfef34ba39
commit 3a9e3a9db3
2 changed files with 9 additions and 0 deletions

View File

@ -215,6 +215,7 @@ class Trainer:
#### save models and training states #### save models and training states
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0: if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
self.model.consolidate_state()
if self.rank <= 0: if self.rank <= 0:
self.logger.info('Saving models and training states.') self.logger.info('Saving models and training states.')
self.model.save(self.current_step) self.model.save(self.current_step)

View File

@ -2,6 +2,7 @@ import os
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel.distributed import DistributedDataParallel
import utils.util import utils.util
@ -127,6 +128,13 @@ class BaseModel():
load_net_clean[k] = v load_net_clean[k] = v
network.load_state_dict(load_net_clean, strict=strict) network.load_state_dict(load_net_clean, strict=strict)
def consolidate_state(self, state=None):
    """Consolidate sharded (ZeRO) optimizer state onto rank 0 before saving.

    ZeroRedundancyOptimizer shards optimizer state across distributed
    ranks; ``consolidate_state_dict(to=0)`` gathers the full state onto
    rank 0 *in place* and returns ``None`` — after this call, rank 0 can
    safely call ``state_dict()`` on the optimizer (as save_training_state
    does). Non-ZeRO optimizers need no consolidation and are skipped.

    Args:
        state: unused; accepted for backward compatibility. The caller in
            the training loop invokes ``consolidate_state()`` with no
            arguments, and appending the ``None`` return value of
            ``consolidate_state_dict`` (as the original did) was a bug.
    """
    for opt in self.optimizers:
        if isinstance(opt, ZeroRedundancyOptimizer):
            # Side-effect only: gathers sharded state onto rank 0.
            opt.consolidate_state_dict(to=0)
def save_training_state(self, state): def save_training_state(self, state):
"""Save training state during training, which will be used for resuming""" """Save training state during training, which will be used for resuming"""
state.update({'schedulers': [], 'optimizers': []}) state.update({'schedulers': [], 'optimizers': []})