From 18938248e4b367da05c327062d1557aa6ccbd62b Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 8 Feb 2022 23:51:31 -0700 Subject: [PATCH] Add batch_size_optimizer support --- codes/trainer/ExtensibleTrainer.py | 124 ++++++++++++++------------ codes/trainer/batch_size_optimizer.py | 66 ++++++++++++++ 2 files changed, 134 insertions(+), 56 deletions(-) create mode 100644 codes/trainer/batch_size_optimizer.py diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 595e4921..5dab9548 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -9,6 +9,7 @@ import torch.nn as nn import trainer.lr_scheduler as lr_scheduler import trainer.networks as networks from trainer.base_model import BaseModel +from trainer.batch_size_optimizer import create_batch_size_optimizer from trainer.inject import create_injector from trainer.steps import ConfigurableStep from trainer.experiments.experiments import get_experiment_for_name @@ -20,6 +21,12 @@ from utils.util import opt_get, denormalize logger = logging.getLogger('base') +# State is immutable to reduce complexity. Overwriting existing state keys is not supported. +class OverwrittenStateError(Exception): + def __init__(self, k, keys): + super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered ' + f'immutable and keys should not be overwritten. 
Current keys: {keys}') + class ExtensibleTrainer(BaseModel): def __init__(self, opt, cached_networks={}): super(ExtensibleTrainer, self).__init__(opt) @@ -50,6 +57,7 @@ class ExtensibleTrainer(BaseModel): self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False) self.checkpointing_cache = opt['checkpointing_enabled'] self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None) + self.batch_size_optimizer = create_batch_size_optimizer(train_opt) self.netsG = {} self.netsD = {} @@ -218,27 +226,27 @@ class ExtensibleTrainer(BaseModel): self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen] - def optimize_parameters(self, step, optimize=True): + def optimize_parameters(self, it, optimize=True): # Some models need to make parametric adjustments per-step. Do that here. for net in self.networks.values(): if hasattr(net.module, "update_for_step"): - net.module.update_for_step(step, os.path.join(self.opt['path']['models'], "..")) + net.module.update_for_step(it, os.path.join(self.opt['path']['models'], "..")) # Iterate through the steps, performing them one at a time. state = self.dstate - for step_num, s in enumerate(self.steps): + for step_num, step in enumerate(self.steps): train_step = True # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps. # Note that the injection points for the step might still be required, so address this by setting train_step=False - if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0: + if 'every' in step.step_opt.keys() and it % step.step_opt['every'] != 0: train_step = False # Steps can opt out of early (or late) training, make sure that happens here. 
- if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']: + if 'after' in step.step_opt.keys() and it < step.step_opt['after'] or 'before' in step.step_opt.keys() and it > step.step_opt['before']: continue # Steps can choose to not execute if a state key is missing. - if 'requires' in s.step_opt.keys(): + if 'requires' in step.step_opt.keys(): requirements_met = True - for requirement in s.step_opt['requires']: + for requirement in step.step_opt['requires']: if requirement not in state.keys(): requirements_met = False if not requirements_met: @@ -246,17 +254,17 @@ class ExtensibleTrainer(BaseModel): if train_step: # Only set requires_grad=True for the network being trained. - nets_to_train = s.get_networks_trained() + nets_to_train = step.get_networks_trained() enabled = 0 for name, net in self.networks.items(): net_enabled = name in nets_to_train if net_enabled: enabled += 1 # Networks can opt out of training before a certain iteration by declaring 'after' in their definition. 
- if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']: + if 'after' in self.opt['networks'][name].keys() and it < self.opt['networks'][name]['after']: net_enabled = False for p in net.parameters(): - do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and step < p.DO_NOT_TRAIN_UNTIL) + do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and it < p.DO_NOT_TRAIN_UNTIL) if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag: p.requires_grad = net_enabled else: @@ -266,13 +274,14 @@ class ExtensibleTrainer(BaseModel): # Update experiments [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments] - for o in s.get_optimizers(): + for o in step.get_optimizers(): o.zero_grad() # Now do a forward and backward pass for each gradient accumulation step. new_states = {} + self.batch_size_optimizer.focus(step.get_optimizers()[-1]) for m in range(self.batch_factor): - ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor)) + ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor)) for k, v in ns.items(): if k not in new_states.keys(): new_states[k] = [v] @@ -281,54 +290,17 @@ class ExtensibleTrainer(BaseModel): # Push the detached new state tensors into the state map for use with the next step. for k, v in new_states.items(): - # State is immutable to reduce complexity. Overwriting existing state keys is not supported. - class OverwrittenStateError(Exception): - def __init__(self, k, keys): - super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered ' - f'immutable and keys should not be overwritten. Current keys: {keys}') if k in state.keys(): raise OverwrittenStateError(k, list(state.keys())) state[k] = v - if train_step and optimize: - # And finally perform optimization. 
- [e.before_optimize(state) for e in self.experiments] - s.do_step(step) - - if s.nan_counter > 10: - if self.auto_recover is None: - print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.") - self.save(step) - self.save_training_state({'iter': step}) - raise ArithmeticError - else: - print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.") - for k, ps in self.save_history.keys(): - if len(ps) < self.auto_recover: - print("Belay that - not enough saves were recorded. Failing instead.") - raise ArithmeticError - if k == '__state__': - self.resume_training(torch.load(ps[-self.auto_recover])) - else: - if k in self.networks.keys(): # This isn't always the case, for example for EMAs. - self.load_network(ps[-self.auto_recover], self.networks[k], strict=True) - self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True) - - # Call into custom step hooks as well as update EMA params. - for name, net in self.networks.items(): - if hasattr(net, "custom_optimizer_step"): - net.custom_optimizer_step(step) - ema_params = self.emas[name].parameters() - net_params = net.parameters() - for ep, np in zip(ema_params, net_params): - if self.ema_on_cpu: - np = np.cpu() - ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate) - [e.after_optimize(state) for e in self.experiments] + # (Maybe) perform a step. + if train_step and optimize and self.batch_size_optimizer.should_step(it): + self.consume_gradients(state, step, it) # Record visual outputs for usage in debugging and testing. 
- if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0: + if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0: def fix_image(img): if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False): img = img.unsqueeze(dim=1) @@ -351,17 +323,54 @@ class ExtensibleTrainer(BaseModel): for rvi in self.opt['logger']['recurrent_visual_indices']: rdbgv = fix_image(dbgv[:, rvi]) os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) - utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i))) + utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (it, rvi, i))) else: dbgv = fix_image(dbgv) os.makedirs(os.path.join(sample_save_path, v), exist_ok=True) - utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i))) + utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (it, i))) # Some models have their own specific visual debug routines. for net_name, net in self.networks.items(): if hasattr(net.module, "visual_dbg"): model_vdbg_dir = os.path.join(sample_save_path, net_name) os.makedirs(model_vdbg_dir, exist_ok=True) - net.module.visual_dbg(step, model_vdbg_dir) + net.module.visual_dbg(it, model_vdbg_dir) + + + def consume_gradients(self, state, step, it): + [e.before_optimize(state) for e in self.experiments] + step.do_step(it) + + if step.nan_counter > 10: + if self.auto_recover is None: + print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.") + self.save(it) + self.save_training_state({'iter': it}) + raise ArithmeticError + else: + print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. 
Restoring to a state {self.auto_recover} saves ago.") + for k, ps in self.save_history.keys(): + if len(ps) < self.auto_recover: + print("Belay that - not enough saves were recorded. Failing instead.") + raise ArithmeticError + if k == '__state__': + self.resume_training(torch.load(ps[-self.auto_recover])) + else: + if k in self.networks.keys(): # This isn't always the case, for example for EMAs. + self.load_network(ps[-self.auto_recover], self.networks[k], strict=True) + self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True) + + # Call into custom step hooks as well as update EMA params. + for name, net in self.networks.items(): + if hasattr(net, "custom_optimizer_step"): + net.custom_optimizer_step(it) + ema_params = self.emas[name].parameters() + net_params = net.parameters() + for ep, np in zip(ema_params, net_params): + if self.ema_on_cpu: + np = np.cpu() + ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate) + [e.after_optimize(state) for e in self.experiments] + def test(self): for net in self.netsG.values(): @@ -416,6 +425,9 @@ class ExtensibleTrainer(BaseModel): for o in self.optimizers: for pgi, pg in enumerate(o.param_groups): log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr'] + + # The batch size optimizer also outputs loggable data. 
import torch


def create_batch_size_optimizer(opt_train):
    """Factory for the trainer's batch size optimizer.

    Returns a GradientDirectionOptimizer when the training options contain a
    `batch_size_optimizer` section whose type is 'gradient_direction';
    otherwise falls back to MegabatchBatchSizeOptimizer, which steps on every
    megabatch (the pre-existing trainer behavior).
    """
    if 'batch_size_optimizer' in opt_train.keys():
        if opt_train['batch_size_optimizer']['type'] == 'gradient_direction':
            return GradientDirectionOptimizer(opt_train)
    return MegabatchBatchSizeOptimizer(opt_train)


# Base class for BatchSizeOptimizers.
class BatchSizeOptimizer:
    def focus(self, optimizer):
        """Hook invoked by the trainer with the optimizer about to be stepped."""
        pass

    def should_step(self, it):
        """Return True when accumulated gradients should be applied at iteration `it`."""
        raise NotImplementedError

    def get_statistics(self):
        """Return a dict of loggable metrics; merged into the trainer's log output."""
        return {}


# BatchSizeOptimizer that just steps every megabatch.
class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
    def __init__(self, opt_train):
        pass

    def should_step(self, it):
        return True


# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
class GradientDirectionOptimizer(BatchSizeOptimizer):
    """Steps the optimizer based on how gradient directions evolve across megabatches.

    NOTE(review): the gradient-direction criterion from the paper is not yet
    implemented — `focus()` installs the bookkeeping fields on the optimizer
    but nothing consumes them. The original `should_step()` returned None
    (falsy), which meant the trainer would accumulate gradients forever and
    never call `optimizer.step()`. As a safe fallback, this version steps once
    `max_full_batches` megabatches have been accumulated so training cannot
    stall. TODO: implement the direction-change heuristic using the state
    registered by focus().
    """

    def __init__(self, opt_train):
        self.mbf = opt_train['mega_batch_factor']
        self.opt = opt_train['batch_size_optimizer']
        # dict.get() is equivalent to opt_get(self.opt, [key], default) for
        # these single-key lookups, and drops the utils.util dependency.
        self.max_full_batches = self.opt.get('max_full_batches', 10)
        # Reserved for the (unimplemented) gradient-direction heuristic.
        self.parameters_to_poll = self.opt.get('poll_parameters', 8)
        self.recalculate_directions_every = self.opt.get('recalculate_directions_steps', 1)
        # Megabatches accumulated since the last step.
        self.last_number_iterations = 0
        # Value of the counter at the moment the last step fired (for logging).
        self.last_number_iterations_before_step = 0

    def vector_angle(self, v1, v2):
        """Return the angle in radians between flattened tensors `v1` and `v2`.

        Fixes two defects in the original implementation:
        - the numerator was the elementwise product `v1 * v2` rather than the
          dot product, producing a tensor of per-element values instead of the
          single angle between the vectors;
        - a zero vector caused division by zero (NaN); it now returns 0.
        """
        with torch.no_grad():
            v1 = v1.flatten()
            v2 = v2.flatten()
            v1_norm = (v1 ** 2).sum().sqrt()
            v2_norm = (v2 ** 2).sum().sqrt()
            if v1_norm == 0 or v2_norm == 0:
                return torch.tensor(0.0, device=v1.device)
            cos = torch.dot(v1, v2) / (v1_norm * v2_norm)
            # Clamp for numerical safety: |cos| can drift slightly past 1.
            return torch.arccos(cos.clamp(-1.0, 1.0))

    def focus(self, optimizer):
        """Mark `optimizer` as the one being trained and reset its bookkeeping state."""
        optimizer._gradient_direction_optimizer_params = []
        optimizer._gradient_direction_optimizer_prior_directions = []
        optimizer._gradient_direction_optimizer_prior_grads = []
        optimizer._gradient_direction_optimizer_direction_change_magnitudes = []
        optimizer._gradient_direction_optimizer_step = 0
        self.current_opt = optimizer

    def should_step(self, it):
        """Return True when the accumulated gradients should be applied.

        Fallback policy (see class docstring): step after `max_full_batches`
        accumulated megabatches. The original returned None — i.e. never.
        """
        self.last_number_iterations += 1
        if self.last_number_iterations >= self.max_full_batches:
            self.last_number_iterations_before_step = self.last_number_iterations
            self.last_number_iterations = 0
            return True
        return False

    def get_statistics(self):
        return {"last_number_iterations_before_step": self.last_number_iterations_before_step}