DL-Art-School/codes/trainer/ExtensibleTrainer.py

469 lines
22 KiB
Python
Raw Normal View History

import copy
2020-08-12 14:45:23 +00:00
import logging
2020-08-22 14:24:34 +00:00
import os
2020-08-12 14:45:23 +00:00
import torch
from torch.nn.parallel import DataParallel
2020-08-22 19:08:33 +00:00
import torch.nn as nn
2020-08-22 14:24:34 +00:00
2020-12-18 16:18:34 +00:00
import trainer.lr_scheduler as lr_scheduler
import trainer.networks as networks
from trainer.base_model import BaseModel
2022-02-09 06:51:31 +00:00
from trainer.batch_size_optimizer import create_batch_size_optimizer
2020-12-30 03:58:02 +00:00
from trainer.inject import create_injector
2020-12-18 16:18:34 +00:00
from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
2020-08-22 19:08:33 +00:00
import torchvision.utils as utils
2020-08-12 14:45:23 +00:00
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize
2020-08-12 14:45:23 +00:00
logger = logging.getLogger('base')
2022-02-09 06:51:31 +00:00
# Trainer state is append-only: once a step has produced a key, no later step may write to it again. Keeping the
# state immutable keeps the data flow between steps simple to reason about.
class OverwrittenStateError(Exception):
    """Raised when a step tries to write a state key that already exists.

    Args:
        k: The key that was about to be overwritten.
        keys: The keys currently present in the state (included in the message for debugging).
    """

    def __init__(self, k, keys):
        message = ('Attempted to overwrite state key: {}. The state should be considered immutable and keys '
                   'should not be overwritten. Current keys: {}').format(k, keys)
        super().__init__(message)
2020-08-12 14:45:23 +00:00
class ExtensibleTrainer(BaseModel):
    """Option-driven trainer that runs a configurable sequence of steps over a set of networks.

    Networks ('generator' or 'discriminator'), steps, eval injectors and experiments are all declared in ``opt``.
    ``feed_data`` splits a batch into micro-batches stored in ``self.dstate``. ``optimize_parameters`` then runs
    every step in order, adding each step's outputs to a shared state dict that later steps can read but never
    overwrite (OverwrittenStateError is raised on an attempt).
    """

    def __init__(self, opt, cached_networks=None):
        """Build networks, steps, optimizers, schedulers, distributed wrappers and (optionally) EMA copies.

        Args:
            opt: Full option dict. Several keys are read with plain indexing and compared against None, so this
                presumably relies on a dict that returns None for missing keys (NOTE(review): confirm).
            cached_networks: Optional mapping of network name -> already-built module. It is reused instead of
                calling ``networks.create_model``.
        """
        super(ExtensibleTrainer, self).__init__(opt)
        if cached_networks is None:
            cached_networks = {}
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # env is used as a global state to store things that subcomponents might need.
        self.env = {'device': self.device,
                    'rank': self.rank,
                    'opt': opt,
                    'step': 0,
                    'dist': opt['dist']
                    }
        if opt['path']['models'] is not None:
            self.env['base_path'] = os.path.join(opt['path']['models'])

        self.mega_batch_factor = 1
        if self.is_train:
            self.mega_batch_factor = train_opt['mega_batch_factor']
            self.env['mega_batch_factor'] = self.mega_batch_factor
        self.batch_factor = self.mega_batch_factor
        # EMA settings are always defined (opt_get tolerates a missing train section) so that save()/load() can
        # read them regardless of is_train. EMA copies themselves are only built when training.
        self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
        self.do_emas = opt_get(train_opt, ['ema_enabled'], True)
        # It is advantageous for large networks to do this to save an extra copy of the model weights.
        # It does come at the cost of a round trip to CPU memory at every batch.
        self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
        self.checkpointing_cache = opt['checkpointing_enabled']
        self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
        self.batch_size_optimizer = create_batch_size_optimizer(train_opt)

        self.netsG = {}
        self.netsD = {}
        for name, net in opt['networks'].items():
            # Trainable is a required parameter, but the default is simply true. Set it here.
            if 'trainable' not in net.keys():
                net['trainable'] = True

            new_net = cached_networks[name] if name in cached_networks.keys() else None
            if net['type'] == 'generator':
                if new_net is None:
                    new_net = networks.create_model(opt, net, self.netsG).to(self.device)
                self.netsG[name] = new_net
            elif net['type'] == 'discriminator':
                if new_net is None:
                    new_net = networks.create_model(opt, net, self.netsD).to(self.device)
                self.netsD[name] = new_net
            else:
                raise NotImplementedError("Can only handle generators and discriminators")

            if not net['trainable']:
                new_net.eval()
            if net['wandb_debug'] and self.rank <= 0:
                import wandb
                wandb.watch(new_net, log='all', log_freq=3)

        # Initialize the train/eval steps
        self.step_names = []
        self.steps = []
        for step_name, step in opt['steps'].items():
            step = ConfigurableStep(step, self.env)
            self.step_names.append(step_name)  # This could be an OrderedDict, but it's a PITA to integrate with AMP below.
            self.steps.append(step)

        # step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
        # they aren't wrapped yet.
        self.env['generators'] = self.netsG
        self.env['discriminators'] = self.netsD

        # Define the optimizers from the steps
        for s in self.steps:
            s.define_optimizers()
            self.optimizers.extend(s.get_optimizers())

        if self.is_train:
            # Find the optimizers that are using the default scheduler, then build them.
            def_opt = []
            for s in self.steps:
                def_opt.extend(s.get_optimizers_with_default_scheduler())
            self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
            # Set the starting step count for the scheduler.
            for sched in self.schedulers:
                sched.last_epoch = opt['current_step']
        else:
            self.schedulers = []

        # Wrap every network (generators first, then discriminators), push the wrapper back into its network dict,
        # and build EMA copies from the unwrapped module.
        self.networks = {}
        self.emas = {}
        for net_dict in (self.netsG, self.netsD):
            for name, anet in list(net_dict.items()):
                dnet = self._wrap_network(anet)
                # train()/eval() on the wrapper propagate to the wrapped module. Only trainable networks go into
                # training mode, so the eval() applied to frozen networks above is not undone here.
                if self.is_train and opt['networks'][name]['trainable']:
                    dnet.train()
                else:
                    dnet.eval()
                net_dict[name] = dnet
                self.networks[name] = dnet
                if self.is_train and self.do_emas:
                    ema = copy.deepcopy(anet)
                    if self.ema_on_cpu:
                        ema = ema.cpu()
                    self.emas[name] = ema

        # Replace the env networks with the wrapped networks
        self.env['generators'] = self.netsG
        self.env['discriminators'] = self.netsD
        self.env['emas'] = self.emas

        self.print_network()  # print network
        self.load()  # load networks from save states as needed

        # Load experiments
        self.experiments = []
        if 'experiments' in opt.keys():
            self.experiments = [get_experiment_for_name(e) for e in opt['experiments']]

        # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
        self.updated = True

    def _wrap_network(self, anet):
        """Wrap one network in DistributedDataParallel (distributed run with trainable params) or DataParallel."""
        opt = self.opt
        # Parameters tagged DO_NOT_TRAIN are ignored when deciding whether DDP is needed.
        has_any_trainable_params = any(not hasattr(p, 'DO_NOT_TRAIN') for p in anet.parameters())
        if not (has_any_trainable_params and opt['dist']):
            return DataParallel(anet, device_ids=[torch.cuda.current_device()])

        if opt['dist_backend'] == 'apex':
            # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
            from apex.parallel import DistributedDataParallel
            return DistributedDataParallel(anet, delay_allreduce=True)

        from torch.nn.parallel.distributed import DistributedDataParallel
        # find_unused_parameters defaults to False: per the original author, True does not work when checkpointing
        # is used (and in a few other cases). It can still be enabled through 'ddp_find_unused_parameters'.
        dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
                                       find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
        # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
        # which does not work with this trainer (as stated above). However, if the graph is not subject
        # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
        # if you are wrong about control flow, DDP will not train all your model parameters! User beware!
        if opt_get(opt, ['ddp_static_graph'], False):
            dnet._set_static_graph()
        return dnet

    def feed_data(self, data, step, need_GT=True, perform_micro_batching=True):
        """Load a batch into ``self.dstate`` as a list of micro-batch chunks per key.

        Args:
            data: Dict of batch values. Tensor values are (optionally sorted,) chunked along dim 0 and moved to
                ``self.device``. Non-tensor values are not stored in ``self.dstate``.
            step: Current global step; stored in ``self.env['step']``.
            need_GT: Unused; kept for interface compatibility.
            perform_micro_batching: When False, every tensor is stored as a single chunk.
        """
        self.env['step'] = step
        self.batch_factor = self.mega_batch_factor
        self.opt['checkpointing_enabled'] = self.checkpointing_cache
        # The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
        if 'train' in self.opt.keys() and \
                'mod_batch_factor' in self.opt['train'].keys() and \
                self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
            self.batch_factor = self.opt['train']['mod_batch_factor']
            if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
                self.opt['checkpointing_enabled'] = False
        self.eval_state = {}
        for o in self.optimizers:
            o.zero_grad()
        torch.cuda.empty_cache()

        sort_key = opt_get(self.opt, ['train', 'sort_key'], None)
        if sort_key is not None:
            sort_indices = torch.sort(data[sort_key], descending=True).indices
        else:
            sort_indices = None

        batch_factor = self.batch_factor if perform_micro_batching else 1
        self.dstate = {}
        for k, v in data.items():
            if sort_indices is not None:
                if isinstance(v, list):
                    v = [v[i] for i in sort_indices]
                else:
                    v = v[sort_indices]
            if isinstance(v, torch.Tensor):
                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]

        if opt_get(self.opt, ['train', 'auto_collate'], False):
            # For every key with a companion '<key>_lengths' entry, slice the last dimension of each 2-4D chunk to
            # the largest length in that chunk (presumably trimming padding beyond the longest sample).
            for k, chunks in self.dstate.items():
                lengths_key = f'{k}_lengths'
                if lengths_key not in self.dstate.keys():
                    continue
                for c, chunk in enumerate(chunks):
                    if 2 <= chunk.dim() <= 4:
                        maxlen = self.dstate[lengths_key][c].max()
                        chunks[c] = chunk[..., :maxlen]

    def optimize_parameters(self, it, optimize=True):
        """Run every configured step on ``self.dstate`` for iteration ``it``.

        Args:
            it: Current iteration; drives 'every'/'after'/'before' gating and visual-debug scheduling.
            optimize: When False, forward/backward passes still run but no optimizer step / EMA update is made.
        """
        # Some models need to make parametric adjustments per-step. Do that here.
        for net in self.networks.values():
            if hasattr(net.module, "update_for_step"):
                net.module.update_for_step(it, os.path.join(self.opt['path']['models'], ".."))

        # Iterate through the steps, performing them one at a time.
        state = self.dstate
        for step_num, step in enumerate(self.steps):
            step_opt = step.step_opt
            # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
            # Note that the injection points for the step might still be required, so address this by setting train_step=False
            train_step = not ('every' in step_opt.keys() and it % step_opt['every'] != 0)
            # Steps can opt out of early (or late) training, make sure that happens here.
            if ('after' in step_opt.keys() and it < step_opt['after']) or \
                    ('before' in step_opt.keys() and it > step_opt['before']):
                continue
            # Steps can choose to not execute if a state key is missing.
            if 'requires' in step_opt.keys() and not all(r in state.keys() for r in step_opt['requires']):
                continue

            focus_net = None
            if train_step:
                nets_to_train = step.get_networks_trained()
                self._set_requires_grad(nets_to_train, it)
                # Update experiments
                for e in self.experiments:
                    e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state)
                for o in step.get_optimizers():
                    o.zero_grad()
                # The batch size optimizer inspects the network being trained by this step. Previously this used
                # whichever network a preceding loop happened to end on.
                if nets_to_train:
                    focus_net = self.networks[nets_to_train[0]]

            # Now do a forward and backward pass for each gradient accumulation step.
            # NOTE(review): assumes feed_data produced self.batch_factor chunks per key.
            new_states = {}
            if focus_net is not None:
                self.batch_size_optimizer.focus(focus_net)
            for m in range(self.batch_factor):
                ns = step.do_forward_backward(state, m, step_num, train=train_step,
                                              no_ddp_sync=(m + 1 < self.batch_factor))
                for k, v in ns.items():
                    new_states.setdefault(k, []).append(v)

            # Push the new per-micro-batch state lists into the state map for use with the next step.
            for k, v in new_states.items():
                if k in state.keys():
                    raise OverwrittenStateError(k, list(state.keys()))
                state[k] = v

            # (Maybe) perform a step.
            if train_step and optimize and self.batch_size_optimizer.should_step(it):
                self.consume_gradients(state, step, it)

        # Record visual outputs for usage in debugging and testing.
        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and \
                it % self.opt['logger']['visual_debug_rate'] == 0:
            self._save_visual_debug(state, it)

    def _set_requires_grad(self, nets_to_train, it):
        """Enable gradients only for parameters of the networks named in ``nets_to_train``."""
        enabled = 0
        for name, net in self.networks.items():
            net_enabled = name in nets_to_train
            if net_enabled:
                enabled += 1
            # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
            net_opt = self.opt['networks'][name]
            if 'after' in net_opt.keys() and it < net_opt['after']:
                net_enabled = False
            for p in net.parameters():
                do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or \
                    (hasattr(p, "DO_NOT_TRAIN_UNTIL") and it < p.DO_NOT_TRAIN_UNTIL)
                # Integer/bool parameters can never require grad.
                if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag:
                    p.requires_grad = net_enabled
                else:
                    p.requires_grad = False
        assert enabled == len(nets_to_train), f'Step trains unknown networks: {nets_to_train}'

    def _save_visual_debug(self, state, it):
        """Save configured ``logger.visuals`` state entries and per-model debug visuals as images under visual_dbg/."""
        def fix_image(img):
            if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
                img = img.unsqueeze(dim=1)
                # Normalize so spectrogram is easier to view.
                img = (img - img.mean()) / img.std()
            if img.shape[1] > 3:
                img = img[:, :3, :, :]
            if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
                img = (img + 1) / 2
            if opt_get(self.opt, ['logger', 'reverse_imagenet_norm'], False):
                img = denormalize(img)
            return img

        sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
        for v in self.opt['logger']['visuals']:
            if v not in state.keys():
                continue  # This can happen for several reasons (ex: 'after' defs), just ignore it.
            visual_dir = os.path.join(sample_save_path, v)
            os.makedirs(visual_dir, exist_ok=True)
            for i, dbgv in enumerate(state[v]):
                if 'recurrent_visual_indices' in self.opt['logger'].keys() and len(dbgv.shape) == 5:
                    for rvi in self.opt['logger']['recurrent_visual_indices']:
                        rdbgv = fix_image(dbgv[:, rvi])
                        utils.save_image(rdbgv.float(), os.path.join(visual_dir, "%05i_%02i_%02i.png" % (it, rvi, i)))
                else:
                    dbgv = fix_image(dbgv)
                    utils.save_image(dbgv.float(), os.path.join(visual_dir, "%05i_%02i.png" % (it, i)))

        # Some models have their own specific visual debug routines.
        for net_name, net in self.networks.items():
            if hasattr(net.module, "visual_dbg"):
                model_vdbg_dir = os.path.join(sample_save_path, net_name)
                os.makedirs(model_vdbg_dir, exist_ok=True)
                net.module.visual_dbg(it, model_vdbg_dir)

    def consume_gradients(self, state, step, it):
        """Apply ``step``'s optimizer step, then run per-network step hooks and update EMA weights."""
        for e in self.experiments:
            e.before_optimize(state)
        step.do_step(it)

        # Call into custom step hooks as well as update EMA params.
        for name, net in self.networks.items():
            # Every entry of self.networks is a DataParallel/DDP wrapper. nn.Module attribute lookup does not reach
            # into the wrapped model, so the hook must be looked up on .module (as all other hooks here are).
            if hasattr(net.module, "custom_optimizer_step"):
                net.module.custom_optimizer_step(it)
            if self.do_emas:
                # The EMA update must not be recorded by autograd.
                with torch.no_grad():
                    for ema_param, net_param in zip(self.emas[name].parameters(), net.parameters()):
                        if self.ema_on_cpu:
                            net_param = net_param.cpu()
                        ema_param.mul_(self.ema_rate).add_(net_param, alpha=1 - self.ema_rate)
        for e in self.experiments:
            e.after_optimize(state)

    def test(self):
        """Run a no-grad evaluation pass and store detached CPU copies of the resulting state in ``self.eval_state``.

        Returns:
            The InfStorageLossAccumulator passed to the steps. It is left untouched when eval injectors are used.
        """
        # Remember each generator's mode so it can be restored afterwards (frozen networks stay in eval()).
        was_training = {name: net.training for name, net in self.netsG.items()}
        for net in self.netsG.values():
            net.eval()
        accum_metrics = InfStorageLossAccumulator()

        with torch.no_grad():
            # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
            # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
            if 'injectors' in self.opt['eval'].keys():
                state = {}
                for inj in self.opt['eval']['injectors'].values():
                    # Need to move from mega_batch mode to batch mode (remove chunks)
                    for k, v in self.dstate.items():
                        state[k] = v[0]
                    inj = create_injector(inj, self.env)
                    state.update(inj(state))
            else:
                # Iterate through the steps, performing them one at a time.
                state = self.dstate
                for step_num, s in enumerate(self.steps):
                    ns = s.do_forward_backward(state, 0, step_num, train=False, loss_accumulator=accum_metrics)
                    for k, v in ns.items():
                        state[k] = [v]

            self.eval_state = {}
            for k, v in state.items():
                values = v if isinstance(v, list) else [v]
                self.eval_state[k] = [t.detach().cpu() if isinstance(t, torch.Tensor) else t for t in values]

        for name, net in self.netsG.items():
            net.train(was_training[name])
        return accum_metrics

    # Fetches a summary of the log.
    def get_current_log(self, step):
        """Collect metrics from steps, experiments, networks, optimizers and the batch size optimizer."""
        log = {}
        for s in self.steps:
            log.update(s.get_metrics())

        for e in self.experiments:
            log.update(e.get_log_data())

        # Some generators can do their own metric logging.
        for net_name, net in self.networks.items():
            if hasattr(net.module, "get_debug_values"):
                log.update(net.module.get_debug_values(step, net_name))

        # Log the learning rate of every param group of every optimizer.
        for o in self.optimizers:
            for pgi, pg in enumerate(o.param_groups):
                log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']

        # The batch size optimizer also outputs loggable data.
        log.update(self.batch_size_optimizer.get_statistics())
        return log

    def get_current_visuals(self, need_GT=True):
        """Return the first eval output (and 'hq' if present) as float CPU tensors."""
        # Conforms to an archaic format from MMSR.
        res = {'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
        if 'hq' in self.eval_state.keys():
            # No trailing comma: this must be a tensor, not a 1-tuple.
            res['hq'] = self.eval_state['hq'][0].float().cpu()
        return res

    def print_network(self):
        """Log each network's wrapper class, parameter count and description (rank 0 / non-dist only)."""
        for name, net in self.networks.items():
            s, n = self.get_network_description(net)
            net_struc_str = '{}'.format(net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
                logger.info(s)

    def load(self):
        """Load pretrained weights (and EMA weights when training) for every network with a configured path."""
        for netdict in [self.netsG, self.netsD]:
            for name, net in netdict.items():
                load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                if load_path is None:
                    # No checkpoint for this network; keep going so later networks still get loaded.
                    continue
                if self.rank <= 0:
                    logger.info('Loading model for [%s]' % (load_path,))
                base_path = opt_get(self.opt, ['path', f'pretrain_base_path_{name}'])
                self.load_network(load_path, net, self.opt['path']['strict_load'], base_path)
                load_path_ema = load_path.replace('.pth', '_ema.pth')
                if self.is_train and self.do_emas:
                    ema_model = self.emas[name]
                    if os.path.exists(load_path_ema):
                        self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], base_path)
                    else:
                        print("WARNING! Unable to find EMA network! Starting a new EMA from given model parameters.")
                        # Copy the bare module (as __init__ does), never the DataParallel/DDP wrapper.
                        self.emas[name] = copy.deepcopy(net.module)
                        if self.ema_on_cpu:
                            self.emas[name] = self.emas[name].cpu()
                if hasattr(net.module, 'network_loaded'):
                    net.module.network_loaded()

    def save(self, iter_step):
        """Save every trainable network, plus its EMA copy when one exists."""
        for name, net in self.networks.items():
            # Don't save non-trainable networks.
            if self.opt['networks'][name]['trainable']:
                self.save_network(net, name, iter_step)
                if self.do_emas and name in self.emas:
                    self.save_network(self.emas[name], f'{name}_ema', iter_step)

    def force_restore_swapout(self):
        # Legacy method. Do nothing.
        pass