DL-Art-School/codes/trainer/ExtensibleTrainer.py
James Betker 7a36668870 whoops!
2022-06-12 21:11:34 -06:00

539 lines
26 KiB
Python

import copy
import logging
import os
from time import time
import torch
from torch import distributed
from torch.nn.parallel import DataParallel
import torch.nn as nn
import trainer.lr_scheduler as lr_scheduler
import trainer.networks as networks
from trainer.base_model import BaseModel
from trainer.batch_size_optimizer import create_batch_size_optimizer
from trainer.inject import create_injector
from trainer.injectors.audio_injectors import normalize_mel
from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
import torchvision.utils as utils
from utils.loss_accumulator import LossAccumulator, InfStorageLossAccumulator
from utils.util import opt_get, denormalize
# Module logger; shares the 'base' logger name (not __name__) so messages go to the 'base' handlers.
logger = logging.getLogger('base')
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
class OverwrittenStateError(Exception):
    """Raised when a step tries to write a state key that is already present.

    Args:
        k: The offending state key.
        keys: The keys currently present in the state (included in the message for debugging).
    """

    def __init__(self, k, keys):
        message = (
            f'Attempted to overwrite state key: {k}. The state should be considered '
            f'immutable and keys should not be overwritten. Current keys: {keys}'
        )
        super(OverwrittenStateError, self).__init__(message)
class ExtensibleTrainer(BaseModel):
    """Config-driven trainer for one or more generator/discriminator networks.

    Networks, training steps (ConfigurableStep), optimizers and learning-rate schedulers are
    all built from the options dict given to the constructor. Each network is wrapped in a
    DistributedDataParallel shell (apex or torch) when distributed training is enabled and it
    has trainable parameters, and in DataParallel otherwise. When training with EMAs enabled,
    an exponential-moving-average copy of every network is kept in ``self.emas``.
    """

    def __init__(self, opt, cached_networks=None):
        """Build networks, steps, optimizers, schedulers and EMA copies from ``opt``.

        Args:
            opt: Full options dict. Among others this reads 'dist', 'train', 'path', 'networks',
                'steps' and 'checkpointing_enabled'; 'current_step' when training,
                'dist_backend' when distributed, and 'experiments' if present.
            cached_networks: Optional mapping of network name -> already-built network to reuse
                instead of calling ``networks.create_model``. ``None`` means no cached networks.
        """
        super(ExtensibleTrainer, self).__init__(opt)
        # Avoid a shared mutable default argument.
        if cached_networks is None:
            cached_networks = {}
        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # env is used as a global state to store things that subcomponents might need.
        self.env = {'device': self.device,
                    'rank': self.rank,
                    'opt': opt,
                    'step': 0,
                    'dist': opt['dist']
                    }
        if opt['path']['models'] is not None:
            self.env['base_path'] = os.path.join(opt['path']['models'])

        self.mega_batch_factor = 1
        if self.is_train:
            self.mega_batch_factor = train_opt['mega_batch_factor']
            self.env['mega_batch_factor'] = self.mega_batch_factor
            self.batch_factor = self.mega_batch_factor
            self.ema_rate = opt_get(train_opt, ['ema_rate'], .999)
            # It is advantageous for large networks to do this to save an extra copy of the model weights.
            # It does come at the cost of a round trip to CPU memory at every batch.
            self.do_emas = opt_get(train_opt, ['ema_enabled'], True)
            self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
        self.checkpointing_cache = opt['checkpointing_enabled']
        self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
        self.batch_size_optimizer = create_batch_size_optimizer(train_opt)

        self.netsG = {}
        self.netsD = {}
        for name, net in opt['networks'].items():
            # Trainable is a required parameter, but the default is simply true. Set it here.
            if 'trainable' not in net.keys():
                net['trainable'] = True

            new_net = cached_networks.get(name)
            if net['type'] == 'generator':
                if new_net is None:
                    new_net = networks.create_model(opt, net, self.netsG).to(self.device)
                self.netsG[name] = new_net
            elif net['type'] == 'discriminator':
                if new_net is None:
                    new_net = networks.create_model(opt, net, self.netsD).to(self.device)
                self.netsD[name] = new_net
            else:
                raise NotImplementedError("Can only handle generators and discriminators")

            if not net['trainable']:
                new_net.eval()
            if net['wandb_debug'] and self.rank <= 0:
                import wandb
                wandb.watch(new_net, log='all', log_freq=3)

        # Initialize the train/eval steps
        self.step_names = []
        self.steps = []
        for step_name, step in opt['steps'].items():
            step = ConfigurableStep(step, self.env)
            self.step_names.append(step_name)  # This could be an OrderedDict, but it's a PITA to integrate with AMP below.
            self.steps.append(step)

        # step.define_optimizers() relies on the networks being placed in the env, so put them there. Even though
        # they aren't wrapped yet.
        self.env['generators'] = self.netsG
        self.env['discriminators'] = self.netsD

        # Define the optimizers from the steps
        for s in self.steps:
            s.define_optimizers()
            self.optimizers.extend(s.get_optimizers())

        if self.is_train:
            # Find the optimizers that are using the default scheduler, then build them.
            def_opt = []
            for s in self.steps:
                def_opt.extend(s.get_optimizers_with_default_scheduler())
            self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt)
            # Set the starting step count for the scheduler.
            for sched in self.schedulers:
                sched.last_epoch = opt['current_step']
        else:
            self.schedulers = []

        # Wrap networks in distributed shells.
        dnets = []
        all_networks = list(self.netsG.values()) + list(self.netsD.values())
        for anet in all_networks:
            # A network counts as trainable if at least one parameter lacks the DO_NOT_TRAIN marker.
            has_any_trainable_params = any(not hasattr(p, 'DO_NOT_TRAIN') for p in anet.parameters())
            if has_any_trainable_params and opt['dist']:
                if opt['dist_backend'] == 'apex':
                    # Use Apex to enable delay_allreduce, which is compatible with gradient checkpointing.
                    from apex.parallel import DistributedDataParallel
                    dnet = DistributedDataParallel(anet, delay_allreduce=True)
                else:
                    from torch.nn.parallel.distributed import DistributedDataParallel
                    # Do NOT be tempted to put find_unused_parameters=True here. It will not work when checkpointing is
                    # used and in a few other cases. But you can try it if you really want.
                    dnet = DistributedDataParallel(anet, device_ids=[torch.cuda.current_device()],
                                                   output_device=torch.cuda.current_device(),
                                                   find_unused_parameters=opt_get(opt, ['ddp_find_unused_parameters'], False))
                    # DDP graphs cannot be used with gradient checkpointing unless you use find_unused_parameters=True,
                    # which does not work with this trainer (as stated above). However, if the graph is not subject
                    # to control flow alterations, you can set this option to allow gradient checkpointing. Beware that
                    # if you are wrong about control flow, DDP will not train all your model parameters! User beware!
                    if opt_get(opt, ['ddp_static_graph'], False):
                        dnet._set_static_graph()
            else:
                dnet = DataParallel(anet, device_ids=[torch.cuda.current_device()])
            if self.is_train:
                dnet.train()
            else:
                dnet.eval()
            dnets.append(dnet)

        # Backpush the wrapped networks into the network dicts. Also build the EMA parameters.
        self.networks = {}
        self.emas = {}
        found = 0
        for dnet in dnets:
            for net_dict in [self.netsD, self.netsG]:
                for k, v in net_dict.items():
                    if v == dnet.module:
                        net_dict[k] = dnet
                        self.networks[k] = dnet
                        if self.is_train and self.do_emas:
                            # The EMA is a copy of the unwrapped module.
                            self.emas[k] = copy.deepcopy(v)
                            if self.ema_on_cpu:
                                self.emas[k] = self.emas[k].cpu()
                        found += 1
        assert found == len(self.netsG) + len(self.netsD)

        # Replace the env networks with the wrapped networks
        self.env['generators'] = self.netsG
        self.env['discriminators'] = self.netsD
        self.env['emas'] = self.emas

        self.print_network()  # print network
        self.load()  # load networks from save states as needed

        # Load experiments
        self.experiments = []
        if 'experiments' in opt.keys():
            self.experiments = [get_experiment_for_name(e) for e in opt['experiments']]

        # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration.
        self.updated = True

    def feed_data(self, data, step, need_GT=True, perform_micro_batching=True):
        """Load a batch into ``self.dstate``, split into micro-batches.

        Each tensor in ``data`` is moved to ``self.device`` and split along dim 0 with
        ``torch.chunk`` into (at most) ``batch_factor`` chunks. ``batch_factor`` is 1 when
        ``perform_micro_batching`` is False. Non-tensor values are not stored in ``self.dstate``.
        If ``train.sort_key`` is configured, every value is first reordered by that key in
        descending order. If ``train.auto_collate`` is set, any ``<k>`` with a matching
        ``<k>_lengths`` entry is trimmed along its last dim to that chunk's max length.

        Also resets all optimizer gradients, empties the CUDA cache, and applies the periodic
        ``mod_batch_factor`` override. ``need_GT`` is accepted for interface compatibility and
        is unused here.
        """
        self.env['step'] = step
        self.batch_factor = self.mega_batch_factor
        self.opt['checkpointing_enabled'] = self.checkpointing_cache
        # The batch factor can be adjusted on a period to allow known high-memory steps to fit in GPU memory.
        if 'train' in self.opt.keys() and \
                'mod_batch_factor' in self.opt['train'].keys() and \
                self.env['step'] % self.opt['train']['mod_batch_factor_every'] == 0:
            self.batch_factor = self.opt['train']['mod_batch_factor']
            if self.opt['train']['mod_batch_factor_also_disable_checkpointing']:
                self.opt['checkpointing_enabled'] = False

        self.eval_state = {}
        for o in self.optimizers:
            o.zero_grad()
        torch.cuda.empty_cache()

        sort_key = opt_get(self.opt, ['train', 'sort_key'], None)
        if sort_key is not None:
            sort_indices = torch.sort(data[sort_key], descending=True).indices
        else:
            sort_indices = None

        batch_factor = self.batch_factor if perform_micro_batching else 1
        self.dstate = {}
        for k, v in data.items():
            if sort_indices is not None:
                if isinstance(v, list):
                    v = [v[i] for i in sort_indices]
                else:
                    v = v[sort_indices]
            if isinstance(v, torch.Tensor):
                self.dstate[k] = [t.to(self.device) for t in torch.chunk(v, chunks=batch_factor, dim=0)]

        if opt_get(self.opt, ['train', 'auto_collate'], False):
            for k, chunks in self.dstate.items():
                lengths_key = f'{k}_lengths'
                if lengths_key not in self.dstate.keys():
                    continue
                for c in range(len(chunks)):
                    maxlen = self.dstate[lengths_key][c].max()
                    # Only rank 2-4 tensors are trimmed; the last dim is the padded length dim.
                    if 2 <= len(chunks[c].shape) <= 4:
                        chunks[c] = chunks[c][..., :maxlen]

    def optimize_parameters(self, it, optimize=True, return_grad_norms=False):
        """Run every configured step on ``self.dstate`` for iteration ``it``.

        For each step that is active at ``it`` (honoring 'every', 'after', 'before' and
        'requires' in the step options), this does the following:
        - toggles ``requires_grad`` so that only the step's networks train;
        - runs forward/backward on every micro-batch;
        - merges the produced state into the shared state, raising OverwrittenStateError on
          key collisions;
        - when ``optimize`` is True and the batch size optimizer allows it, steps the
          optimizers through ``consume_gradients``.

        Also writes visual debug images every ``logger.visual_debug_rate`` iterations on rank <= 0.

        Returns:
            dict of parameter-group name -> CPU tensor holding the L2 norm of its gradients.
            Only populated when ``return_grad_norms`` is True and an optimizer step is taken.
        """
        grad_norms = {}

        # Some models need to make parametric adjustments per-step. Do that here.
        for net in self.networks.values():
            if hasattr(net.module, "update_for_step"):
                net.module.update_for_step(it, os.path.join(self.opt['path']['models'], ".."))

        # Iterate through the steps, performing them one at a time.
        state = self.dstate
        for step_num, step in enumerate(self.steps):
            step_opt = step.step_opt
            train_step = True
            # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
            # Note that the injection points for the step might still be required, so address this by setting train_step=False
            if 'every' in step_opt.keys() and it % step_opt['every'] != 0:
                train_step = False
            # Steps can opt out of early (or late) training, make sure that happens here.
            if ('after' in step_opt.keys() and it < step_opt['after']) or \
                    ('before' in step_opt.keys() and it > step_opt['before']):
                continue
            # Steps can choose to not execute if a state key is missing.
            if 'requires' in step_opt.keys():
                if not all(requirement in state.keys() for requirement in step_opt['requires']):
                    continue

            if train_step:
                # Only set requires_grad=True for the network being trained.
                nets_to_train = step.get_networks_trained()
                enabled = 0
                for name, net in self.networks.items():
                    net_enabled = name in nets_to_train
                    if net_enabled:
                        enabled += 1
                    # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
                    if 'after' in self.opt['networks'][name].keys() and it < self.opt['networks'][name]['after']:
                        net_enabled = False
                    for p in net.parameters():
                        do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and it < p.DO_NOT_TRAIN_UNTIL)
                        # Integer/bool parameters can never require grad.
                        if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag:
                            p.requires_grad = net_enabled
                        else:
                            p.requires_grad = False
                assert enabled == len(nets_to_train)

                # Update experiments
                for e in self.experiments:
                    e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state)

                for o in step.get_optimizers():
                    o.zero_grad()

            # Now do a forward and backward pass for each gradient accumulation step.
            new_states = {}
            # NOTE(review): `net` here is whatever the most recent loop over self.networks bound it to,
            # i.e. the last network in self.networks, not necessarily one trained by this step.
            # Confirm that this is what the batch size optimizer expects.
            self.batch_size_optimizer.focus(net)
            for m in range(self.batch_factor):
                # DDP gradient sync is skipped for every micro-batch except the last.
                ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m + 1 < self.batch_factor))
                # Call into post-backward hooks.
                for name, net in self.networks.items():
                    if hasattr(net.module, "after_backward"):
                        net.module.after_backward(it)

                for k, v in ns.items():
                    new_states.setdefault(k, []).append(v)

            # Push the detached new state tensors into the state map for use with the next step.
            for k, v in new_states.items():
                if k in state.keys():
                    raise OverwrittenStateError(k, list(state.keys()))
                state[k] = v

            # (Maybe) perform a step.
            if train_step and optimize and self.batch_size_optimizer.should_step(it):
                # Call into pre-step hooks.
                for name, net in self.networks.items():
                    if hasattr(net.module, "before_step"):
                        net.module.before_step(it)

                # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
                # This is needed to accurately log the grad norms.
                from torch.cuda.amp.grad_scaler import OptState
                for optimizer in step.optimizers:
                    if step.scaler.is_enabled() and step.scaler._per_optimizer_states[id(optimizer)]["stage"] is not OptState.UNSCALED:
                        step.scaler.unscale_(optimizer)

                if return_grad_norms and train_step:
                    for name in nets_to_train:
                        model = self.networks[name]
                        if hasattr(model.module, 'get_grad_norm_parameter_groups'):
                            pgroups = {f'{name}_{k}': v for k, v in model.module.get_grad_norm_parameter_groups().items()}
                        else:
                            pgroups = {f'{name}_all_parameters': list(model.parameters())}
                        for group_name, group_params in pgroups.items():
                            stacked_grads = [torch.norm(p.grad.detach(), 2) for p in group_params
                                             if hasattr(p, 'grad') and p.grad is not None]
                            if not stacked_grads:
                                # NOTE(review): in distributed mode, a rank that skips here does not join the
                                # all_reduce below. Confirm that all ranks agree on which groups have grads.
                                continue
                            grad_norm = torch.norm(torch.stack(stacked_grads), 2)
                            if distributed.is_available() and distributed.is_initialized():
                                # Gather the metric from all devices if in a distributed setting.
                                distributed.all_reduce(grad_norm, op=distributed.ReduceOp.SUM)
                                grad_norm /= distributed.get_world_size()
                            grad_norms[group_name] = grad_norm.cpu()

                self.consume_gradients(state, step, it)

        # Record visual outputs for usage in debugging and testing.
        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0:
            def fix_image(img):
                # Make a state tensor savable as an image, per the logger options.
                if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
                    if img.min() < -2:
                        img = normalize_mel(img)
                    img = img.unsqueeze(dim=1)
                if img.shape[1] > 3:
                    img = img[:, :3, :, :]
                if opt_get(self.opt, ['logger', 'reverse_n1_to_1'], False):
                    img = (img + 1) / 2
                if opt_get(self.opt, ['logger', 'reverse_imagenet_norm'], False):
                    img = denormalize(img)
                return img

            sample_save_path = os.path.join(self.opt['path']['models'], "..", "visual_dbg")
            for v in self.opt['logger']['visuals']:
                if v not in state.keys():
                    continue  # This can happen for several reasons (ex: 'after' defs), just ignore it.
                for i, dbgv in enumerate(state[v]):
                    if 'recurrent_visual_indices' in self.opt['logger'].keys() and len(dbgv.shape) == 5:
                        for rvi in self.opt['logger']['recurrent_visual_indices']:
                            rdbgv = fix_image(dbgv[:, rvi])
                            os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
                            utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (it, rvi, i)))
                    else:
                        dbgv = fix_image(dbgv)
                        os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
                        utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (it, i)))
            # Some models have their own specific visual debug routines.
            for net_name, net in self.networks.items():
                if hasattr(net.module, "visual_dbg"):
                    model_vdbg_dir = os.path.join(sample_save_path, net_name)
                    os.makedirs(model_vdbg_dir, exist_ok=True)
                    net.module.visual_dbg(it, model_vdbg_dir)

        return grad_norms

    def consume_gradients(self, state, step, it):
        """Apply ``step``'s optimizer update for iteration ``it``, then update the EMA weights.

        Experiment hooks run before and after. Networks' ``after_step`` hooks are called. When
        EMAs live on the CPU they are only updated every 10 iterations, with a decay adjusted
        for the skipped updates.
        """
        for e in self.experiments:
            e.before_optimize(state)
        self.restore_optimizers()
        step.do_step(it)
        self.stash_optimizers()

        # Call into custom step hooks as well as update EMA params.
        for name, net in self.networks.items():
            if hasattr(net.module, "after_step"):
                net.module.after_step(it)
            if self.do_emas:
                # When the EMA is on the CPU, only update every 10 steps to save processing time.
                if self.ema_on_cpu and it % 10 != 0:
                    continue
                # The decay is the same for every parameter, so compute it once per network.
                ema_rate = self.ema_rate
                if self.ema_on_cpu:
                    new_rate = 1 - ema_rate
                    ema_rate = ema_rate ** 10  # Because it only happens every 10 steps.
                    mid = (1 - (ema_rate + new_rate)) / 2
                    ema_rate += mid
                with torch.no_grad():
                    for ema_param, net_param in zip(self.emas[name].parameters(), net.parameters()):
                        if self.ema_on_cpu:
                            net_param = net_param.cpu()
                        ema_param.detach().mul_(ema_rate).add_(net_param, alpha=1 - ema_rate)
        for e in self.experiments:
            e.after_optimize(state)

    def test(self):
        """Run a no-grad evaluation pass over ``self.dstate``.

        If ``opt['eval']`` defines 'injectors', those injectors are run on the first chunk of
        every ``self.dstate`` entry. Otherwise every step runs in non-training mode on chunk 0.
        CPU copies of the resulting state are stored in ``self.eval_state`` (every value wrapped
        in a list). Generators are put in eval mode for the pass and switched back to train mode
        afterwards.

        Returns:
            The InfStorageLossAccumulator passed to the steps. It stays untouched when eval
            injectors are used.
        """
        for net in self.netsG.values():
            net.eval()

        accum_metrics = InfStorageLossAccumulator()
        with torch.no_grad():
            # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
            # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
            if 'injectors' in self.opt['eval'].keys():
                state = {}
                for inj in self.opt['eval']['injectors'].values():
                    # Need to move from mega_batch mode to batch mode (remove chunks)
                    for k, v in self.dstate.items():
                        state[k] = v[0]
                    inj = create_injector(inj, self.env)
                    state.update(inj(state))
            else:
                # Iterate through the steps, performing them one at a time.
                state = self.dstate
                for step_num, s in enumerate(self.steps):
                    ns = s.do_forward_backward(state, 0, step_num, train=False, loss_accumulator=accum_metrics)
                    for k, v in ns.items():
                        state[k] = [v]

            self.eval_state = {}
            for k, v in state.items():
                if isinstance(v, list):
                    self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
                else:
                    self.eval_state[k] = [v.detach().cpu() if isinstance(v, torch.Tensor) else v]

        for net in self.netsG.values():
            net.train()
        return accum_metrics

    # Fetches a summary of the log.
    def get_current_log(self, step):
        """Collect loggable values for iteration ``step``.

        The log combines step metrics, experiment data, per-network debug values, the learning
        rate of every optimizer param group, and batch-size-optimizer statistics. In distributed
        mode, 1-D float32 tensors are averaged across ranks.
        """
        log = {}
        for s in self.steps:
            log.update(s.get_metrics())

        for e in self.experiments:
            log.update(e.get_log_data())

        # Some generators can do their own metric logging.
        for net_name, net in self.networks.items():
            if hasattr(net.module, "get_debug_values"):
                log.update(net.module.get_debug_values(step, net_name))

        # Log the learning rate of every param group of every optimizer too.
        for o in self.optimizers:
            for pgi, pg in enumerate(o.param_groups):
                log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']

        # The batch size optimizer also outputs loggable data.
        log.update(self.batch_size_optimizer.get_statistics())

        # In distributed mode, average 1-D float32 tensors across all ranks. Other values are left as-is.
        if distributed.is_available() and distributed.is_initialized():
            for k, v in log.items():
                if not isinstance(v, torch.Tensor):
                    continue
                if len(v.shape) != 1 or v.dtype != torch.float:
                    continue
                distributed.all_reduce(v, op=distributed.ReduceOp.SUM)
                log[k] = v / distributed.get_world_size()

        return log

    def get_current_visuals(self, need_GT=True):
        """Return the first chunk of the eval output (and of 'hq' if present) as float CPU tensors.

        ``need_GT`` is accepted for interface compatibility and is unused.
        """
        # Conforms to an archaic format from MMSR.
        res = {'rlt': self.eval_state[self.opt['eval']['output_state']][0].float().cpu()}
        if 'hq' in self.eval_state.keys():
            res['hq'] = self.eval_state['hq'][0].float().cpu()
        return res

    def print_network(self):
        """Log each network's class name, parameter count and description (rank <= 0 only)."""
        for name, net in self.networks.items():
            s, n = self.get_network_description(net)
            net_struc_str = '{}'.format(net.__class__.__name__)
            if self.rank <= 0:
                logger.info('Network {} structure: {}, with parameters: {:,d}'.format(name, net_struc_str, n))
                logger.info(s)

    def load(self):
        """Load pretrained weights for every network that has a ``pretrain_model_<name>`` path.

        Networks without a path are skipped; the remaining networks are still processed. When
        training with EMAs, the matching ``*_ema.pth`` file is loaded if it exists. Otherwise the
        EMA is re-seeded from the freshly loaded network. Each loaded network's
        ``network_loaded`` hook is called if it has one.
        """
        for netdict in [self.netsG, self.netsD]:
            for name, net in netdict.items():
                load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                if load_path is None:
                    # Skip just this network; other networks may still have pretrained weights.
                    continue
                if self.rank <= 0:
                    logger.info('Loading model for [%s]' % (load_path,))
                self.load_network(load_path, net, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
                load_path_ema = load_path.replace('.pth', '_ema.pth')
                if self.is_train and self.do_emas:
                    ema_model = self.emas[name]
                    if os.path.exists(load_path_ema):
                        self.load_network(load_path_ema, ema_model, self.opt['path']['strict_load'], opt_get(self.opt, ['path', f'pretrain_base_path_{name}']))
                    else:
                        print("WARNING! Unable to find EMA network! Starting a new EMA from given model parameters.")
                        self.emas[name] = copy.deepcopy(net)
                        if self.ema_on_cpu:
                            self.emas[name] = self.emas[name].cpu()
                if hasattr(net.module, 'network_loaded'):
                    net.module.network_loaded()

    def save(self, iter_step):
        """Save every trainable network (and its EMA when EMAs are enabled) at ``iter_step``."""
        for name, net in self.networks.items():
            # Don't save non-trainable networks.
            if self.opt['networks'][name]['trainable']:
                self.save_network(net, name, iter_step)
                if self.do_emas:
                    self.save_network(self.emas[name], f'{name}_ema', iter_step)

    def force_restore_swapout(self):
        # Legacy method. Do nothing.
        pass