Add batch_size_optimizer support

parent 9e9ae328f2
commit 18938248e4
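The new behaviour is driven by a `batch_size_optimizer` block inside the training options. For orientation before the diff, a training-options snippet that would exercise the gradient-direction path might look like the sketch below; the key names and defaults are taken from the new batch_size_optimizer.py, while the surrounding option structure is an assumption.

# Sketch only: assumed shape of the training options consumed by
# create_batch_size_optimizer(). Key names and defaults mirror batch_size_optimizer.py.
train_opt = {
    'mega_batch_factor': 4,                 # gradient-accumulation factor read by GradientDirectionOptimizer
    'batch_size_optimizer': {
        'type': 'gradient_direction',       # any other value falls back to MegabatchBatchSizeOptimizer
        'max_full_batches': 10,             # default: 10
        'poll_parameters': 8,               # default: 8
        'recalculate_directions_steps': 1,  # default: 1
    },
}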
codes/trainer/ExtensibleTrainer.py

@@ -9,6 +9,7 @@ import torch.nn as nn
 import trainer.lr_scheduler as lr_scheduler
 import trainer.networks as networks
 from trainer.base_model import BaseModel
+from trainer.batch_size_optimizer import create_batch_size_optimizer
 from trainer.inject import create_injector
 from trainer.steps import ConfigurableStep
 from trainer.experiments.experiments import get_experiment_for_name
@@ -20,6 +21,12 @@ from utils.util import opt_get, denormalize
 logger = logging.getLogger('base')


+# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
+class OverwrittenStateError(Exception):
+    def __init__(self, k, keys):
+        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
+                         f'immutable and keys should not be overwritten. Current keys: {keys}')
+
 class ExtensibleTrainer(BaseModel):
     def __init__(self, opt, cached_networks={}):
         super(ExtensibleTrainer, self).__init__(opt)
@@ -50,6 +57,7 @@ class ExtensibleTrainer(BaseModel):
         self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
+        self.batch_size_optimizer = create_batch_size_optimizer(train_opt)

         self.netsG = {}
         self.netsD = {}
@@ -218,27 +226,27 @@ class ExtensibleTrainer(BaseModel):
                     self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]


-    def optimize_parameters(self, step, optimize=True):
+    def optimize_parameters(self, it, optimize=True):
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
             if hasattr(net.module, "update_for_step"):
-                net.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
+                net.module.update_for_step(it, os.path.join(self.opt['path']['models'], ".."))

         # Iterate through the steps, performing them one at a time.
         state = self.dstate
-        for step_num, s in enumerate(self.steps):
+        for step_num, step in enumerate(self.steps):
             train_step = True
             # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
             # Note that the injection points for the step might still be required, so address this by setting train_step=False
-            if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0:
+            if 'every' in step.step_opt.keys() and it % step.step_opt['every'] != 0:
                 train_step = False
             # Steps can opt out of early (or late) training, make sure that happens here.
-            if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
+            if 'after' in step.step_opt.keys() and it < step.step_opt['after'] or 'before' in step.step_opt.keys() and it > step.step_opt['before']:
                 continue
             # Steps can choose to not execute if a state key is missing.
-            if 'requires' in s.step_opt.keys():
+            if 'requires' in step.step_opt.keys():
                 requirements_met = True
-                for requirement in s.step_opt['requires']:
+                for requirement in step.step_opt['requires']:
                     if requirement not in state.keys():
                         requirements_met = False
                 if not requirements_met:
@@ -246,17 +254,17 @@ class ExtensibleTrainer(BaseModel):

             if train_step:
                 # Only set requires_grad=True for the network being trained.
-                nets_to_train = s.get_networks_trained()
+                nets_to_train = step.get_networks_trained()
                 enabled = 0
                 for name, net in self.networks.items():
                     net_enabled = name in nets_to_train
                     if net_enabled:
                         enabled += 1
                     # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
-                    if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
+                    if 'after' in self.opt['networks'][name].keys() and it < self.opt['networks'][name]['after']:
                         net_enabled = False
                     for p in net.parameters():
-                        do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and step < p.DO_NOT_TRAIN_UNTIL)
+                        do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and it < p.DO_NOT_TRAIN_UNTIL)
                         if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag:
                             p.requires_grad = net_enabled
                         else:
@@ -266,13 +274,14 @@ class ExtensibleTrainer(BaseModel):
                 # Update experiments
                 [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]

-                for o in s.get_optimizers():
+                for o in step.get_optimizers():
                     o.zero_grad()

             # Now do a forward and backward pass for each gradient accumulation step.
             new_states = {}
+            self.batch_size_optimizer.focus(step.get_optimizers()[-1])
             for m in range(self.batch_factor):
-                ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
+                ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
@@ -281,54 +290,17 @@ class ExtensibleTrainer(BaseModel):

             # Push the detached new state tensors into the state map for use with the next step.
             for k, v in new_states.items():
-                # State is immutable to reduce complexity. Overwriting existing state keys is not supported.
-                class OverwrittenStateError(Exception):
-                    def __init__(self, k, keys):
-                        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
-                                         f'immutable and keys should not be overwritten. Current keys: {keys}')
                 if k in state.keys():
                     raise OverwrittenStateError(k, list(state.keys()))
                 state[k] = v

-            if train_step and optimize:
-                # And finally perform optimization.
-                [e.before_optimize(state) for e in self.experiments]
-                s.do_step(step)
-
-                if s.nan_counter > 10:
-                    if self.auto_recover is None:
-                        print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
-                        self.save(step)
-                        self.save_training_state({'iter': step})
-                        raise ArithmeticError
-                    else:
-                        print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
-                        for k, ps in self.save_history.keys():
-                            if len(ps) < self.auto_recover:
-                                print("Belay that - not enough saves were recorded. Failing instead.")
-                                raise ArithmeticError
-                            if k == '__state__':
-                                self.resume_training(torch.load(ps[-self.auto_recover]))
-                            else:
-                                if k in self.networks.keys():  # This isn't always the case, for example for EMAs.
-                                    self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
-                                self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
-
-                # Call into custom step hooks as well as update EMA params.
-                for name, net in self.networks.items():
-                    if hasattr(net, "custom_optimizer_step"):
-                        net.custom_optimizer_step(step)
-                    ema_params = self.emas[name].parameters()
-                    net_params = net.parameters()
-                    for ep, np in zip(ema_params, net_params):
-                        if self.ema_on_cpu:
-                            np = np.cpu()
-                        ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
-                [e.after_optimize(state) for e in self.experiments]
+            # (Maybe) perform a step.
+            if train_step and optimize and self.batch_size_optimizer.should_step(it):
+                self.consume_gradients(state, step, it)


         # Record visual outputs for usage in debugging and testing.
-        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
+        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0:
             def fix_image(img):
                 if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
                     img = img.unsqueeze(dim=1)
@@ -351,17 +323,54 @@ class ExtensibleTrainer(BaseModel):
                         for rvi in self.opt['logger']['recurrent_visual_indices']:
                             rdbgv = fix_image(dbgv[:, rvi])
                             os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                            utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                            utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (it, rvi, i)))
                     else:
                         dbgv = fix_image(dbgv)
                         os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                        utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                        utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (it, i)))
             # Some models have their own specific visual debug routines.
             for net_name, net in self.networks.items():
                 if hasattr(net.module, "visual_dbg"):
                     model_vdbg_dir = os.path.join(sample_save_path, net_name)
                     os.makedirs(model_vdbg_dir, exist_ok=True)
-                    net.module.visual_dbg(step, model_vdbg_dir)
+                    net.module.visual_dbg(it, model_vdbg_dir)
+
+    def consume_gradients(self, state, step, it):
+        [e.before_optimize(state) for e in self.experiments]
+        step.do_step(it)
+
+        if step.nan_counter > 10:
+            if self.auto_recover is None:
+                print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
+                self.save(it)
+                self.save_training_state({'iter': it})
+                raise ArithmeticError
+            else:
+                print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
+                for k, ps in self.save_history.items():
+                    if len(ps) < self.auto_recover:
+                        print("Belay that - not enough saves were recorded. Failing instead.")
+                        raise ArithmeticError
+                    if k == '__state__':
+                        self.resume_training(torch.load(ps[-self.auto_recover]))
+                    else:
+                        if k in self.networks.keys():  # This isn't always the case, for example for EMAs.
+                            self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
+                        self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
+
+        # Call into custom step hooks as well as update EMA params.
+        for name, net in self.networks.items():
+            if hasattr(net, "custom_optimizer_step"):
+                net.custom_optimizer_step(it)
+            ema_params = self.emas[name].parameters()
+            net_params = net.parameters()
+            for ep, np in zip(ema_params, net_params):
+                if self.ema_on_cpu:
+                    np = np.cpu()
+                ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
+        [e.after_optimize(state) for e in self.experiments]


     def test(self):
         for net in self.netsG.values():
@@ -416,6 +425,9 @@ class ExtensibleTrainer(BaseModel):
         for o in self.optimizers:
             for pgi, pg in enumerate(o.param_groups):
                 log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']
+
+        # The batch size optimizer also outputs loggable data.
+        log.update(self.batch_size_optimizer.get_statistics())
         return log

     def get_current_visuals(self, need_GT=True):
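Taken together, the ExtensibleTrainer changes above decouple gradient accumulation from the optimizer step: optimize_parameters still runs a forward/backward pass for every micro-batch of the megabatch, but the actual step (plus the EMA update and NaN recovery, now in consume_gradients) only happens when the batch size optimizer agrees. A condensed, illustrative sketch of the new control flow, not the full method:

# Condensed sketch of the restructured optimize_parameters() (illustrative only).
def optimize_parameters(self, it, optimize=True):
    state = self.dstate
    for step_num, step in enumerate(self.steps):
        # Point the batch size optimizer at the optimizer that will eventually step.
        self.batch_size_optimizer.focus(step.get_optimizers()[-1])

        # Accumulate gradients across each micro-batch of the megabatch.
        for m in range(self.batch_factor):
            step.do_forward_backward(state, m, step_num, train=True)

        # Only consume the accumulated gradients when the batch size optimizer says so.
        if optimize and self.batch_size_optimizer.should_step(it):
            self.consume_gradients(state, step, it)  # step.do_step(it), EMA update, NaN recovery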
codes/trainer/batch_size_optimizer.py (new file, 66 lines)

@@ -0,0 +1,66 @@
+import torch
+
+from utils.util import opt_get
+
+
+def create_batch_size_optimizer(opt_train):
+    if 'batch_size_optimizer' in opt_train.keys():
+        if opt_train['batch_size_optimizer']['type'] == 'gradient_direction':
+            return GradientDirectionOptimizer(opt_train)
+    return MegabatchBatchSizeOptimizer(opt_train)
+
+
+# Base class for BatchSizeOptimizers.
+class BatchSizeOptimizer:
+    def focus(self, optimizer):
+        pass
+
+    def should_step(self, it):
+        raise NotImplementedError
+
+    def get_statistics(self):
+        return {}
+
+
+# BatchSizeOptimizer that just steps every megabatch.
+class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
+    def __init__(self, opt_train):
+        pass
+
+    def should_step(self, it):
+        return True
+
+
+# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
+# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
+class GradientDirectionOptimizer(BatchSizeOptimizer):
+    def __init__(self, opt_train):
+        self.mbf = opt_train['mega_batch_factor']
+        self.opt = opt_train['batch_size_optimizer']
+        self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
+        self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
+        self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
+        self.last_number_iterations = 0
+
+    def vector_angle(self, v1, v2):
+        with torch.no_grad():
+            v1 = v1.flatten()
+            v2 = v2.flatten()
+            v1_norm = (v1 ** 2).sum().sqrt()
+            v2_norm = (v2 ** 2).sum().sqrt()
+            angle = torch.arccos((v1 * v2).sum() / (v1_norm * v2_norm))  # dot product over the product of norms
+            return angle
+
+    def focus(self, optimizer):
+        optimizer._gradient_direction_optimizer_params = []
+        optimizer._gradient_direction_optimizer_prior_directions = []
+        optimizer._gradient_direction_optimizer_prior_grads = []
+        optimizer._gradient_direction_optimizer_direction_change_magnitudes = []
+        optimizer._gradient_direction_optimizer_step = 0
+        self.current_opt = optimizer
+
+    def should_step(self, it):
+        self.last_number_iterations += 1
+
+    def get_statistics(self):
+        return {"last_number_iterations_before_step": self.last_number_iterations}
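Note that, as committed, GradientDirectionOptimizer.should_step() only increments last_number_iterations and implicitly returns None, so the gradient-direction path never actually triggers a step yet; only MegabatchBatchSizeOptimizer (which always returns True) is functional. One plausible completion, in the spirit of the cited paper (step once the direction of the accumulated gradient stops changing, or after max_full_batches accumulation passes), is sketched below. This is purely illustrative, not the author's implementation; the parameter-polling strategy and the 0.1-radian threshold are assumptions, and it further assumes focus() is re-invoked only when a new accumulation window begins, so the bookkeeping survives across trainer iterations.

# Illustrative sketch only - NOT the committed implementation.
# Assumes focus() was called at the start of the accumulation window, so
# self.current_opt holds the optimizer being watched.
def should_step(self, it):
    self.last_number_iterations += 1
    opt = self.current_opt

    if not opt._gradient_direction_optimizer_params:
        # Poll a small, fixed subset of parameters that actually have gradients.
        opt._gradient_direction_optimizer_params = [
            p for group in opt.param_groups for p in group['params'] if p.grad is not None
        ][:self.parameters_to_poll]

    grads = [p.grad.detach().clone() for p in opt._gradient_direction_optimizer_params]
    prior = opt._gradient_direction_optimizer_prior_grads
    opt._gradient_direction_optimizer_prior_grads = grads
    opt._gradient_direction_optimizer_step += 1

    # Hard cap: always step once enough megabatches have been accumulated.
    if opt._gradient_direction_optimizer_step >= self.max_full_batches:
        return True
    # Need at least two accumulation passes before a direction change can be measured.
    if not prior:
        return False

    # Step once the accumulated gradient direction has stabilised (small mean angle
    # between successive passes); the 0.1 radian threshold is an assumption.
    angles = [self.vector_angle(a, b) for a, b in zip(prior, grads)]
    return torch.stack(angles).mean().item() < 0.1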