DL-Art-School/codes/trainer/base_model.py
2022-03-04 12:46:41 -07:00

196 lines
7.9 KiB
Python

import os
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel.distributed import DistributedDataParallel
import utils.util
from utils.util import opt_get, optimizer_to
class BaseModel():
    """Shared training scaffolding: optimizer/scheduler bookkeeping, LR warm-up, and checkpoint save/load.

    The hook methods (``feed_data``, ``optimize_parameters``, ``get_current_visuals``, ...) are no-op stubs
    here; presumably concrete models override them.
    """

    def __init__(self, opt):
        """Store the options mapping and initialize shared training state.

        Args:
            opt: Options mapping. Keys read here: 'dist', 'gpu_ids', 'amp_opt_level', 'is_train' and,
                optionally, 'keep_optimizer_states_on_cpu' (defaults to False).
        """
        self.opt = opt
        # Process rank under distributed training; -1 marks non-distributed training.
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1
        # Current CUDA device index when GPUs are configured, otherwise the CPU device.
        self.device = torch.cuda.current_device() if opt['gpu_ids'] else torch.device('cpu')
        self.amp_level = 'O0' if opt['amp_opt_level'] is None else opt['amp_opt_level']
        self.is_train = opt['is_train']
        # When True, stash_optimizers/restore_optimizers move optimizer state between CPU and self.device.
        self.opt_in_cpu = opt_get(opt, ['keep_optimizer_states_on_cpu'], False)
        self.schedulers = []
        self.optimizers = []
        self.disc_optimizers = []
        # Maps a network label (or '__state__' for training states) to the list of paths saved, in save order.
        self.save_history = {}

    def feed_data(self, data):
        """Hook: no-op in the base class."""
        pass

    def optimize_parameters(self):
        """Hook: no-op in the base class."""
        pass

    def get_current_visuals(self):
        """Hook: no-op in the base class."""
        pass

    def get_current_losses(self):
        """Hook: no-op in the base class."""
        pass

    def print_network(self):
        """Hook: no-op in the base class."""
        pass

    def save(self, label):
        """Hook: no-op in the base class."""
        pass

    def load(self):
        """Hook: no-op in the base class."""
        pass

    def _set_lr(self, lr_groups_l):
        """Set the learning rate of every param group (used for warm-up).

        Args:
            lr_groups_l: One list of learning rates per optimizer in ``self.optimizers``, each holding one lr per
                param group. Extra entries on either side are ignored (``zip`` semantics).
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Return, per optimizer, the list of each param group's 'initial_lr' (normally set by its scheduler)."""
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, cur_iter, warmup_iter=-1):
        """Step every scheduler at ``cur_iter`` and apply linear warm-up while still warming up.

        Args:
            cur_iter: Current training iteration. Each scheduler's ``last_epoch`` is overwritten with this value
                before ``step()`` is called.
            warmup_iter: Warm-up length in iterations. While ``cur_iter < warmup_iter`` every param group's lr is
                overwritten with ``initial_lr / warmup_iter * cur_iter``. The default -1 disables warm-up.
        """
        for scheduler in self.schedulers:
            scheduler.last_epoch = cur_iter
            scheduler.step()
        if cur_iter < warmup_iter:
            init_lr_g_l = self._get_init_lr()
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        """Return the lr of each param group of the first optimizer (IndexError if there are no optimizers)."""
        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]

    def get_network_description(self, network):
        """Return ``(str(network), total parameter count)``, unwrapping DataParallel/DDP first."""
        if isinstance(network, (nn.DataParallel, DistributedDataParallel)):
            network = network.module
        return str(network), sum(map(lambda x: x.numel(), network.parameters()))

    def save_network(self, network, network_label, iter_label):
        """Save ``network``'s weights, moved to CPU, as '<iter_label>_<network_label>.pth'.

        The file is written under ``opt['path']['models']``. A second copy is written under
        ``opt['path']['alt_path']`` when that key exists, and the file is uploaded with
        ``utils.util.copy_files_to_server`` when ``opt['colab_mode']`` is set.

        Returns:
            The local path the state dict was written to.
        """
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(self.opt['path']['models'], save_filename)
        # Save the wrapped module so the keys do not carry the wrapper's 'module.' prefix.
        if isinstance(network, (nn.DataParallel, DistributedDataParallel)):
            network = network.module
        state_dict = network.state_dict()
        # Replace values in place (no keys added or removed, so this is safe while iterating) to keep the
        # exact dict object returned by state_dict().
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)
        self.save_history.setdefault(network_label, []).append(save_path)
        # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
        if 'alt_path' in self.opt['path'].keys():
            torch.save(state_dict, os.path.join(self.opt['path']['alt_path'], save_filename))
        if self.opt['colab_mode']:
            utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
                                            save_path, os.path.join(self.opt['remote_path'], 'models', save_filename))
        return save_path

    @staticmethod
    def _clean_state_dict_keys(load_net, pretrain_base_path=None):
        """Normalize checkpoint keys so they match a bare (unwrapped) module.

        Args:
            load_net: Mapping of parameter name -> value read from a checkpoint.
            pretrain_base_path: Optional key prefix. When given, only entries whose key starts with it are kept,
                with the prefix removed.

        Returns:
            An OrderedDict in the input's order, in which a single *leading* 'module.' (the prefix a
            DataParallel/DDP wrapper adds) has been stripped from each key. Only the leading occurrence is
            removed, so submodule names that merely contain 'module.' (e.g. 'submodule.weight') stay intact.
        """
        if pretrain_base_path is not None:
            load_net = {k[len(pretrain_base_path):]: v
                        for k, v in load_net.items() if k.startswith(pretrain_base_path)}
        prefix = 'module.'
        load_net_clean = OrderedDict()
        for k, v in load_net.items():
            if k.startswith(prefix):
                k = k[len(prefix):]
            load_net_clean[k] = v
        return load_net_clean

    def load_network(self, load_path, network, strict=True, pretrain_base_path=None):
        """Load checkpoint weights from ``load_path`` into ``network``.

        Args:
            load_path: Checkpoint file passed to ``torch.load``. Its device mapping is delegated to
                ``utils.util.map_cuda_to_correct_device``.
            network: Target module. If it has a ``module`` attribute (e.g. a DDP/DataParallel wrapper), weights
                are loaded into that inner module.
            strict: Forwarded to ``load_state_dict``.
            pretrain_base_path: Optional key prefix. Only checkpoint entries under it are loaded, with the
                prefix removed.
        """
        # Sometimes networks are passed in as DDP modules, we want the raw parameters.
        if hasattr(network, 'module'):
            network = network.module
        load_net = torch.load(load_path, map_location=utils.util.map_cuda_to_correct_device)
        # Accept checkpoints that nest the weights under a 'state_dict' key as well as bare state dicts.
        if 'state_dict' in load_net:
            load_net = load_net['state_dict']
        network.load_state_dict(self._clean_state_dict_keys(load_net, pretrain_base_path), strict=strict)

    def consolidate_state(self):
        """Call ``consolidate_state_dict(to=0)`` on every ZeroRedundancyOptimizer so rank 0 can save full state."""
        for o in self.optimizers:
            if isinstance(o, ZeroRedundancyOptimizer):
                o.consolidate_state_dict(to=0)

    def save_training_state(self, state):
        """Save training state (optimizers, schedulers, optional apex AMP) for resuming.

        ``state`` is mutated in place. 'schedulers' and 'optimizers' lists are added, plus 'amp' when
        applicable. The result is written to '<iter>.state' under ``opt['path']['training_state']``, where
        <iter> is ``state['iter']`` as read by ``opt_get``, falling back to 'no_step_provided'. A copy is written
        as 'latest.state' under ``opt['path']['alt_path']`` when that key exists, and the file is uploaded when
        ``opt['colab_mode']`` is set.

        Args:
            state: Dict of extra training metadata to persist (presumably including 'iter').
        """
        state.update({'schedulers': [], 'optimizers': []})
        for s in self.schedulers:
            state['schedulers'].append(s.state_dict())
        for o in self.optimizers:
            state['optimizers'].append(o.state_dict())
        if 'amp_opt_level' in self.opt.keys():
            # `amp` must be imported here: it is not a module-level name. apex is imported lazily, mirroring
            # resume_training. If apex cannot be imported, no apex AMP state can exist to be saved.
            try:
                from apex import amp
            except ImportError:
                amp = None
            if amp is not None:
                state['amp'] = amp.state_dict()
        save_filename = '{}.state'.format(utils.util.opt_get(state, ['iter'], 'no_step_provided'))
        save_path = os.path.join(self.opt['path']['training_state'], save_filename)
        torch.save(state, save_path)
        self.save_history.setdefault('__state__', []).append(save_path)
        # Also save to the 'alt_path' which is useful for caching to Google Drive in colab, for example.
        if 'alt_path' in self.opt['path'].keys():
            torch.save(state, os.path.join(self.opt['path']['alt_path'], 'latest.state'))
        if self.opt['colab_mode']:
            utils.util.copy_files_to_server(self.opt['ssh_server'], self.opt['ssh_username'], self.opt['ssh_password'],
                                            save_path, os.path.join(self.opt['remote_path'], 'training_state', save_filename))

    def stash_optimizers(self):
        """
        When enabled, puts all optimizer states in CPU memory, allowing forward and backward passes more memory
        headroom. No-op unless ``self.opt_in_cpu`` is set.
        """
        if not self.opt_in_cpu:
            return
        for opt in self.optimizers:
            optimizer_to(opt, 'cpu')

    def restore_optimizers(self):
        """
        Puts optimizer states back into device memory (``self.device``). No-op unless ``self.opt_in_cpu`` is set.
        """
        if not self.opt_in_cpu:
            return
        for opt in self.optimizers:
            optimizer_to(opt, self.device)

    def resume_training(self, resume_state, load_amp=True):
        """Restore optimizers and schedulers (and optionally apex AMP) from a saved training state.

        Args:
            resume_state: Dict produced by ``save_training_state``. Its 'optimizers' and 'schedulers' lists must
                match ``self.optimizers`` / ``self.schedulers`` in length (asserted).
            load_amp: When True and the state has an 'amp' entry, load it into apex AMP (imported lazily).

        After loading, optimizer state is stashed to CPU if ``self.opt_in_cpu`` is set.
        """
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)
        if load_amp and 'amp' in resume_state.keys():
            from apex import amp
            amp.load_state_dict(resume_state['amp'])
        self.stash_optimizers()