Add batch_size_optimizer support

This commit is contained in:
James Betker 2022-02-08 23:51:31 -07:00
parent 9e9ae328f2
commit 18938248e4
2 changed files with 134 additions and 56 deletions

View File

@ -9,6 +9,7 @@ import torch.nn as nn
import trainer.lr_scheduler as lr_scheduler
import trainer.networks as networks
from trainer.base_model import BaseModel
from trainer.batch_size_optimizer import create_batch_size_optimizer
from trainer.inject import create_injector
from trainer.steps import ConfigurableStep
from trainer.experiments.experiments import get_experiment_for_name
@ -20,6 +21,12 @@ from utils.util import opt_get, denormalize
logger = logging.getLogger('base')
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
class OverwrittenStateError(Exception):
def __init__(self, k, keys):
super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
f'immutable and keys should not be overwritten. Current keys: {keys}')
class ExtensibleTrainer(BaseModel):
def __init__(self, opt, cached_networks={}):
super(ExtensibleTrainer, self).__init__(opt)
@ -50,6 +57,7 @@ class ExtensibleTrainer(BaseModel):
self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
self.checkpointing_cache = opt['checkpointing_enabled']
self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
self.batch_size_optimizer = create_batch_size_optimizer(train_opt)
self.netsG = {}
self.netsD = {}
@ -218,27 +226,27 @@ class ExtensibleTrainer(BaseModel):
self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
def optimize_parameters(self, step, optimize=True):
def optimize_parameters(self, it, optimize=True):
# Some models need to make parametric adjustments per-step. Do that here.
for net in self.networks.values():
if hasattr(net.module, "update_for_step"):
net.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
net.module.update_for_step(it, os.path.join(self.opt['path']['models'], ".."))
# Iterate through the steps, performing them one at a time.
state = self.dstate
for step_num, s in enumerate(self.steps):
for step_num, step in enumerate(self.steps):
train_step = True
# 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
# Note that the injection points for the step might still be required, so address this by setting train_step=False
if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0:
if 'every' in step.step_opt.keys() and it % step.step_opt['every'] != 0:
train_step = False
# Steps can opt out of early (or late) training, make sure that happens here.
if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
if 'after' in step.step_opt.keys() and it < step.step_opt['after'] or 'before' in step.step_opt.keys() and it > step.step_opt['before']:
continue
# Steps can choose to not execute if a state key is missing.
if 'requires' in s.step_opt.keys():
if 'requires' in step.step_opt.keys():
requirements_met = True
for requirement in s.step_opt['requires']:
for requirement in step.step_opt['requires']:
if requirement not in state.keys():
requirements_met = False
if not requirements_met:
@ -246,17 +254,17 @@ class ExtensibleTrainer(BaseModel):
if train_step:
# Only set requires_grad=True for the network being trained.
nets_to_train = s.get_networks_trained()
nets_to_train = step.get_networks_trained()
enabled = 0
for name, net in self.networks.items():
net_enabled = name in nets_to_train
if net_enabled:
enabled += 1
# Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
if 'after' in self.opt['networks'][name].keys() and it < self.opt['networks'][name]['after']:
net_enabled = False
for p in net.parameters():
do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and step < p.DO_NOT_TRAIN_UNTIL)
do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and it < p.DO_NOT_TRAIN_UNTIL)
if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag:
p.requires_grad = net_enabled
else:
@ -266,13 +274,14 @@ class ExtensibleTrainer(BaseModel):
# Update experiments
[e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
for o in s.get_optimizers():
for o in step.get_optimizers():
o.zero_grad()
# Now do a forward and backward pass for each gradient accumulation step.
new_states = {}
self.batch_size_optimizer.focus(step.get_optimizers()[-1])
for m in range(self.batch_factor):
ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
for k, v in ns.items():
if k not in new_states.keys():
new_states[k] = [v]
@ -281,54 +290,17 @@ class ExtensibleTrainer(BaseModel):
# Push the detached new state tensors into the state map for use with the next step.
for k, v in new_states.items():
# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
class OverwrittenStateError(Exception):
def __init__(self, k, keys):
super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
f'immutable and keys should not be overwritten. Current keys: {keys}')
if k in state.keys():
raise OverwrittenStateError(k, list(state.keys()))
state[k] = v
if train_step and optimize:
# And finally perform optimization.
[e.before_optimize(state) for e in self.experiments]
s.do_step(step)
if s.nan_counter > 10:
if self.auto_recover is None:
print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
self.save(step)
self.save_training_state({'iter': step})
raise ArithmeticError
else:
print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
for k, ps in self.save_history.keys():
if len(ps) < self.auto_recover:
print("Belay that - not enough saves were recorded. Failing instead.")
raise ArithmeticError
if k == '__state__':
self.resume_training(torch.load(ps[-self.auto_recover]))
else:
if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
# Call into custom step hooks as well as update EMA params.
for name, net in self.networks.items():
if hasattr(net, "custom_optimizer_step"):
net.custom_optimizer_step(step)
ema_params = self.emas[name].parameters()
net_params = net.parameters()
for ep, np in zip(ema_params, net_params):
if self.ema_on_cpu:
np = np.cpu()
ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
[e.after_optimize(state) for e in self.experiments]
# (Maybe) perform a step.
if train_step and optimize and self.batch_size_optimizer.should_step(it):
self.consume_gradients(state, step, it)
# Record visual outputs for usage in debugging and testing.
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0:
def fix_image(img):
if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
img = img.unsqueeze(dim=1)
@ -351,17 +323,54 @@ class ExtensibleTrainer(BaseModel):
for rvi in self.opt['logger']['recurrent_visual_indices']:
rdbgv = fix_image(dbgv[:, rvi])
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (it, rvi, i)))
else:
dbgv = fix_image(dbgv)
os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (it, i)))
# Some models have their own specific visual debug routines.
for net_name, net in self.networks.items():
if hasattr(net.module, "visual_dbg"):
model_vdbg_dir = os.path.join(sample_save_path, net_name)
os.makedirs(model_vdbg_dir, exist_ok=True)
net.module.visual_dbg(step, model_vdbg_dir)
net.module.visual_dbg(it, model_vdbg_dir)
def consume_gradients(self, state, step, it):
[e.before_optimize(state) for e in self.experiments]
step.do_step(it)
if step.nan_counter > 10:
if self.auto_recover is None:
print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
self.save(it)
self.save_training_state({'iter': it})
raise ArithmeticError
else:
print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
for k, ps in self.save_history.keys():
if len(ps) < self.auto_recover:
print("Belay that - not enough saves were recorded. Failing instead.")
raise ArithmeticError
if k == '__state__':
self.resume_training(torch.load(ps[-self.auto_recover]))
else:
if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
# Call into custom step hooks as well as update EMA params.
for name, net in self.networks.items():
if hasattr(net, "custom_optimizer_step"):
net.custom_optimizer_step(it)
ema_params = self.emas[name].parameters()
net_params = net.parameters()
for ep, np in zip(ema_params, net_params):
if self.ema_on_cpu:
np = np.cpu()
ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
[e.after_optimize(state) for e in self.experiments]
def test(self):
for net in self.netsG.values():
@ -416,6 +425,9 @@ class ExtensibleTrainer(BaseModel):
for o in self.optimizers:
for pgi, pg in enumerate(o.param_groups):
log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']
# The batch size optimizer also outputs loggable data.
log.update(self.batch_size_optimizer.get_statistics())
return log
def get_current_visuals(self, need_GT=True):

View File

@ -0,0 +1,66 @@
import torch
from utils.util import opt_get
def create_batch_size_optimizer(opt_train):
if 'batch_size_optimizer' in opt_train.keys():
if opt_train['batch_size_optimizer']['type'] == 'gradient_direction':
return GradientDirectionOptimizer(opt_train)
return MegabatchBatchSizeOptimizer(opt_train)
# Base class for BatchSizeOptimizers.
class BatchSizeOptimizer:
def focus(self, optimizer):
pass
def should_step(self, it):
raise NotImplementedError
def get_statistics(self):
return {}
# BatchSizeOptimizer that just steps every megabatch.
class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
def __init__(self, opt_train):
pass
def should_step(self, it):
return True
# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
class GradientDirectionOptimizer(BatchSizeOptimizer):
def __init__(self, opt_train):
self.mbf = opt_train['mega_batch_factor']
self.opt = opt_train['batch_size_optimizer']
self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
self.last_number_iterations = 0
def vector_angle(self, v1, v2):
with torch.no_grad():
v1 = v1.flatten()
v2 = v2.flatten()
v1_norm = (v1 ** 2).sum().sqrt()
v2_norm = (v2 ** 2).sum().sqrt()
angle = torch.arccos((v1 * v2) / (v1_norm * v2_norm))
return angle
def focus(self, optimizer):
optimizer._gradient_direction_optimizer_params = []
optimizer._gradient_direction_optimizer_prior_directions = []
optimizer._gradient_direction_optimizer_prior_grads = []
optimizer._gradient_direction_optimizer_direction_change_magnitudes = []
optimizer._gradient_direction_optimizer_step = 0
self.current_opt = optimizer
def should_step(self, it):
self.last_number_iterations += 1
def get_statistics(self):
return {"last_number_iterations_before_step": self.last_number_iterations}