consolidate state
This commit is contained in:
parent
dfef34ba39
commit
3a9e3a9db3
|
@ -215,6 +215,7 @@ class Trainer:
|
||||||
|
|
||||||
#### save models and training states
|
#### save models and training states
|
||||||
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
if self.current_step % opt['logger']['save_checkpoint_freq'] == 0:
|
||||||
|
self.model.consolidate_state()
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
self.logger.info('Saving models and training states.')
|
self.logger.info('Saving models and training states.')
|
||||||
self.model.save(self.current_step)
|
self.model.save(self.current_step)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||||
|
|
||||||
import utils.util
|
import utils.util
|
||||||
|
@ -127,6 +128,13 @@ class BaseModel():
|
||||||
load_net_clean[k] = v
|
load_net_clean[k] = v
|
||||||
network.load_state_dict(load_net_clean, strict=strict)
|
network.load_state_dict(load_net_clean, strict=strict)
|
||||||
|
|
||||||
|
|
||||||
|
def consolidate_state(self, state=None):
    """Gather sharded (ZeRO) optimizer state onto rank 0 before saving.

    ``ZeroRedundancyOptimizer`` shards optimizer state across ranks;
    ``consolidate_state_dict(to=0)`` is a collective call that collects
    the full state onto rank 0 so that a subsequent ``state_dict()``
    call there returns complete state.  It returns ``None``, so its
    result must NOT be appended to ``state['optimizers']`` — the
    original code stored ``None`` entries, corrupting the saved
    training state.

    Every rank must call this before rank 0 saves a checkpoint (the
    training loop calls it with no argument: ``model.consolidate_state()``).

    Args:
        state (dict, optional): accepted for backward compatibility
            with callers that pass the training-state dict; it is no
            longer mutated here.
    """
    for optimizer in self.optimizers:
        if isinstance(optimizer, ZeroRedundancyOptimizer):
            # Collective: gathers the sharded state onto rank 0.
            # Returns None by design — do not append its result.
            optimizer.consolidate_state_dict(to=0)
|
||||||
|
|
||||||
|
|
||||||
def save_training_state(self, state):
|
def save_training_state(self, state):
|
||||||
"""Save training state during training, which will be used for resuming"""
|
"""Save training state during training, which will be used for resuming"""
|
||||||
state.update({'schedulers': [], 'optimizers': []})
|
state.update({'schedulers': [], 'optimizers': []})
|
||||||
|
|
Loading…
Reference in New Issue
Block a user