consolidate state

This commit is contained in:
James Betker 2022-01-24 17:59:31 -07:00
parent dfef34ba39
commit 3a9e3a9db3
2 changed files with 9 additions and 0 deletions

View File

@ -215,6 +215,7 @@ class Trainer:
#### save models and training states
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
self.model.consolidate_state()
if self.rank <= 0:
self.logger.info('Saving models and training states.')
self.model.save(self.current_step)

View File

@ -2,6 +2,7 @@ import os
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel
import utils.util
@ -127,6 +128,13 @@ class BaseModel():
load_net_clean[k] = v
network.load_state_dict(load_net_clean, strict=strict)
def consolidate_state(self, state=None):
    """Gather sharded ZeRO optimizer state onto rank 0 ahead of a checkpoint save.

    ``ZeroRedundancyOptimizer`` shards optimizer state across ranks.
    ``consolidate_state_dict(to=0)`` collects all shards onto rank 0 as a
    side effect and returns ``None`` — a subsequent ``state_dict()`` call on
    rank 0 then sees the full state. The original code appended that
    ``None`` return value to ``state['optimizers']``, which stored nothing
    useful; the consolidation call itself is all that is needed here.

    Args:
        state: Optional dict with an ``'optimizers'`` list, accepted for
            backward compatibility with the old signature. The call site in
            the trainer (``self.model.consolidate_state()``) passes no
            argument, so ``state`` must have a default or that call raises
            ``TypeError``.
    """
    for opt in self.optimizers:
        # Non-ZeRO optimizers keep their full state locally; nothing to do.
        if isinstance(opt, ZeroRedundancyOptimizer):
            # Side-effect only: pulls every rank's shard onto rank 0.
            opt.consolidate_state_dict(to=0)
def save_training_state(self, state):
"""Save training state during training, which will be used for resuming"""
state.update({'schedulers': [], 'optimizers': []})