consolidate state
parent dfef34ba39
commit 3a9e3a9db3
@@ -215,6 +215,7 @@ class Trainer:
 
             #### save models and training states
             if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
+                self.model.consolidate_state()
                 if self.rank <= 0:
                     self.logger.info('Saving models and training states.')
                     self.model.save(self.current_step)
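`ZeroRedundancyOptimizer.consolidate_state_dict` is a collective call, which is why the new `consolidate_state()` hook runs on every rank before the rank-0-only save block. A minimal sketch of that ordering, using an illustrative standalone save helper rather than this repo's `Trainer`/model API:

```python
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def save_checkpoint(model, optimizer, step, rank, path='checkpoint.pt'):
    # Collective step: every rank must participate so the sharded optimizer
    # state can be gathered onto rank 0.
    if isinstance(optimizer, ZeroRedundancyOptimizer):
        optimizer.consolidate_state_dict(to=0)

    # Only rank 0 touches the filesystem.
    if rank == 0:
        torch.save({
            'step': step,
            'model': model.state_dict(),
            # Full optimizer state is available on rank 0 after consolidation.
            'optimizer': optimizer.state_dict(),
        }, path)

    # Keep the other ranks from racing ahead while rank 0 writes.
    dist.barrier()
```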
@@ -2,6 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
+from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import utils.util
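For context, `ZeroRedundancyOptimizer` shards optimizer state across the process group and is built by wrapping a plain optimizer class. How this repo constructs its optimizers is not shown in the hunk, so the sketch below is a generic assumption (Adam and the learning rate are placeholders):

```python
import torch
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer


def build_sharded_optimizer(model: nn.Module, lr: float = 1e-4) -> ZeroRedundancyOptimizer:
    # Each rank keeps only its shard of the Adam state (exp_avg, exp_avg_sq),
    # which is why a full checkpoint later needs consolidate_state_dict().
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.Adam,
        lr=lr,
    )
```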
@@ -127,6 +128,13 @@ class BaseModel():
                 load_net_clean[k] = v
         network.load_state_dict(load_net_clean, strict=strict)
 
+    def consolidate_state(self, state):
+        for o in self.optimizers:
+            if isinstance(o, ZeroRedundancyOptimizer):
+                state['optimizers'].append(o.consolidate_state_dict(to=0))
+
     def save_training_state(self, state):
         """Save training state during training, which will be used for resuming"""
         state.update({'schedulers': [], 'optimizers': []})
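One caveat: `consolidate_state_dict(to=0)` returns `None` and only gathers the sharded state onto the destination rank, so appending its return value stores `None` rather than the optimizer state. A hedged sketch of splitting consolidation from collection, assuming rank 0 is the saving rank; the class stub and the scheduler/optimizer loops fill in the part of `save_training_state` that the hunk truncates and are illustrative, not taken from the repo:

```python
from torch.distributed.optim import ZeroRedundancyOptimizer


class BaseModelSketch:
    """Illustrative stand-in for BaseModel; only the checkpoint-related parts."""

    def __init__(self, optimizers, schedulers):
        self.optimizers = optimizers
        self.schedulers = schedulers

    def consolidate_state(self):
        # Collective step: every rank calls this before rank 0 reads state_dict().
        for o in self.optimizers:
            if isinstance(o, ZeroRedundancyOptimizer):
                o.consolidate_state_dict(to=0)

    def save_training_state(self, state):
        """Save training state during training, which will be used for resuming"""
        state.update({'schedulers': [], 'optimizers': []})
        for s in self.schedulers:
            state['schedulers'].append(s.state_dict())
        for o in self.optimizers:
            # On rank 0 this is the full, consolidated state after
            # consolidate_state(); on other ranks it is only the local shard.
            state['optimizers'].append(o.state_dict())
```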