Add batch_size_optimizer support
parent 9e9ae328f2
commit 18938248e4
@@ -9,6 +9,7 @@ import torch.nn as nn
 import trainer.lr_scheduler as lr_scheduler
 import trainer.networks as networks
 from trainer.base_model import BaseModel
+from trainer.batch_size_optimizer import create_batch_size_optimizer
 from trainer.inject import create_injector
 from trainer.steps import ConfigurableStep
 from trainer.experiments.experiments import get_experiment_for_name
@@ -20,6 +21,12 @@ from utils.util import opt_get, denormalize
 logger = logging.getLogger('base')
 
 
+# State is immutable to reduce complexity. Overwriting existing state keys is not supported.
+class OverwrittenStateError(Exception):
+    def __init__(self, k, keys):
+        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
+                         f'immutable and keys should not be overwritten. Current keys: {keys}')
+
 class ExtensibleTrainer(BaseModel):
     def __init__(self, opt, cached_networks={}):
         super(ExtensibleTrainer, self).__init__(opt)
@@ -50,6 +57,7 @@ class ExtensibleTrainer(BaseModel):
         self.ema_on_cpu = opt_get(train_opt, ['ema_on_cpu'], False)
         self.checkpointing_cache = opt['checkpointing_enabled']
         self.auto_recover = opt_get(opt, ['automatically_recover_nan_by_reverting_n_saves'], None)
+        self.batch_size_optimizer = create_batch_size_optimizer(train_opt)
 
         self.netsG = {}
         self.netsD = {}
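For context, the `train_opt` dict passed to `create_batch_size_optimizer` here is the trainer's existing `train` options block. As a rough illustration only: the key names below are the ones the new codes/trainer/batch_size_optimizer.py reads, while the values are made-up defaults, not taken from this commit.

# Hypothetical options fragment; key names from batch_size_optimizer.py,
# values are illustrative only.
train_opt = {
    'mega_batch_factor': 4,                    # consumed by GradientDirectionOptimizer.__init__
    'batch_size_optimizer': {
        'type': 'gradient_direction',          # anything else falls back to MegabatchBatchSizeOptimizer
        'max_full_batches': 10,                # opt_get default in the new file
        'poll_parameters': 8,                  # opt_get default in the new file
        'recalculate_directions_steps': 1,     # opt_get default in the new file
    },
}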
@@ -218,27 +226,27 @@ class ExtensibleTrainer(BaseModel):
                     self.dstate[k][c] = self.dstate[k][c][:, :, :, :maxlen]
 
 
-    def optimize_parameters(self, step, optimize=True):
+    def optimize_parameters(self, it, optimize=True):
         # Some models need to make parametric adjustments per-step. Do that here.
         for net in self.networks.values():
             if hasattr(net.module, "update_for_step"):
-                net.module.update_for_step(step, os.path.join(self.opt['path']['models'], ".."))
+                net.module.update_for_step(it, os.path.join(self.opt['path']['models'], ".."))
 
         # Iterate through the steps, performing them one at a time.
         state = self.dstate
-        for step_num, s in enumerate(self.steps):
+        for step_num, step in enumerate(self.steps):
             train_step = True
             # 'every' is used to denote steps that should only occur at a certain integer factor rate. e.g. '2' occurs every 2 steps.
             # Note that the injection points for the step might still be required, so address this by setting train_step=False
-            if 'every' in s.step_opt.keys() and step % s.step_opt['every'] != 0:
+            if 'every' in step.step_opt.keys() and it % step.step_opt['every'] != 0:
                 train_step = False
             # Steps can opt out of early (or late) training, make sure that happens here.
-            if 'after' in s.step_opt.keys() and step < s.step_opt['after'] or 'before' in s.step_opt.keys() and step > s.step_opt['before']:
+            if 'after' in step.step_opt.keys() and it < step.step_opt['after'] or 'before' in step.step_opt.keys() and it > step.step_opt['before']:
                 continue
             # Steps can choose to not execute if a state key is missing.
-            if 'requires' in s.step_opt.keys():
+            if 'requires' in step.step_opt.keys():
                 requirements_met = True
-                for requirement in s.step_opt['requires']:
+                for requirement in step.step_opt['requires']:
                     if requirement not in state.keys():
                         requirements_met = False
                 if not requirements_met:
@@ -246,17 +254,17 @@ class ExtensibleTrainer(BaseModel):
 
             if train_step:
                 # Only set requires_grad=True for the network being trained.
-                nets_to_train = s.get_networks_trained()
+                nets_to_train = step.get_networks_trained()
                 enabled = 0
                 for name, net in self.networks.items():
                     net_enabled = name in nets_to_train
                     if net_enabled:
                         enabled += 1
                     # Networks can opt out of training before a certain iteration by declaring 'after' in their definition.
-                    if 'after' in self.opt['networks'][name].keys() and step < self.opt['networks'][name]['after']:
+                    if 'after' in self.opt['networks'][name].keys() and it < self.opt['networks'][name]['after']:
                         net_enabled = False
                     for p in net.parameters():
-                        do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and step < p.DO_NOT_TRAIN_UNTIL)
+                        do_not_train_flag = hasattr(p, "DO_NOT_TRAIN") or (hasattr(p, "DO_NOT_TRAIN_UNTIL") and it < p.DO_NOT_TRAIN_UNTIL)
                         if p.dtype != torch.int64 and p.dtype != torch.bool and not do_not_train_flag:
                             p.requires_grad = net_enabled
                         else:
@@ -266,13 +274,14 @@ class ExtensibleTrainer(BaseModel):
             # Update experiments
             [e.before_step(self.opt, self.step_names[step_num], self.env, nets_to_train, state) for e in self.experiments]
 
-            for o in s.get_optimizers():
+            for o in step.get_optimizers():
                 o.zero_grad()
 
             # Now do a forward and backward pass for each gradient accumulation step.
             new_states = {}
+            self.batch_size_optimizer.focus(step.get_optimizers()[-1])
             for m in range(self.batch_factor):
-                ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
+                ns = step.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
@@ -281,54 +290,17 @@ class ExtensibleTrainer(BaseModel):
 
             # Push the detached new state tensors into the state map for use with the next step.
             for k, v in new_states.items():
-                # State is immutable to reduce complexity. Overwriting existing state keys is not supported.
-                class OverwrittenStateError(Exception):
-                    def __init__(self, k, keys):
-                        super().__init__(f'Attempted to overwrite state key: {k}. The state should be considered '
-                                         f'immutable and keys should not be overwritten. Current keys: {keys}')
                 if k in state.keys():
                     raise OverwrittenStateError(k, list(state.keys()))
                 state[k] = v
 
-            if train_step and optimize:
-                # And finally perform optimization.
-                [e.before_optimize(state) for e in self.experiments]
-                s.do_step(step)
-
-                if s.nan_counter > 10:
-                    if self.auto_recover is None:
-                        print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
-                        self.save(step)
-                        self.save_training_state({'iter': step})
-                        raise ArithmeticError
-                    else:
-                        print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
-                        for k, ps in self.save_history.keys():
-                            if len(ps) < self.auto_recover:
-                                print("Belay that - not enough saves were recorded. Failing instead.")
-                                raise ArithmeticError
-                            if k == '__state__':
-                                self.resume_training(torch.load(ps[-self.auto_recover]))
-                            else:
-                                if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
-                                    self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
-                                self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
-
-                # Call into custom step hooks as well as update EMA params.
-                for name, net in self.networks.items():
-                    if hasattr(net, "custom_optimizer_step"):
-                        net.custom_optimizer_step(step)
-                    ema_params = self.emas[name].parameters()
-                    net_params = net.parameters()
-                    for ep, np in zip(ema_params, net_params):
-                        if self.ema_on_cpu:
-                            np = np.cpu()
-                        ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
-                [e.after_optimize(state) for e in self.experiments]
+            # (Maybe) perform a step.
+            if train_step and optimize and self.batch_size_optimizer.should_step(it):
+                self.consume_gradients(state, step, it)
 
 
         # Record visual outputs for usage in debugging and testing.
-        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and step % self.opt['logger']['visual_debug_rate'] == 0:
+        if 'visuals' in self.opt['logger'].keys() and self.rank <= 0 and it % self.opt['logger']['visual_debug_rate'] == 0:
             def fix_image(img):
                 if opt_get(self.opt, ['logger', 'is_mel_spectrogram'], False):
                     img = img.unsqueeze(dim=1)
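The net effect of this hunk: rather than unconditionally stepping after every accumulation cycle, the trainer now asks the batch size optimizer whether enough gradient has been accumulated. A condensed sketch of the resulting control flow, with simplified stand-ins rather than the trainer's own objects:

# Condensed sketch of the new gating; 'step' stands in for a ConfigurableStep-like
# object and 'bso' for a BatchSizeOptimizer.
def run_one_iteration(it, step, bso, state, batch_factor, train_step=True, optimize=True):
    bso.focus(step.get_optimizers()[-1])        # tell the BSO which optimizer to watch
    for m in range(batch_factor):               # gradient accumulation, as before
        step.do_forward_backward(state, m, 0, train=train_step)
    if train_step and optimize and bso.should_step(it):
        step.do_step(it)                        # consume_gradients() wraps this plus EMA/hooks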
@@ -351,17 +323,54 @@ class ExtensibleTrainer(BaseModel):
                         for rvi in self.opt['logger']['recurrent_visual_indices']:
                             rdbgv = fix_image(dbgv[:, rvi])
                             os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                            utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (step, rvi, i)))
+                            utils.save_image(rdbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i_%02i.png" % (it, rvi, i)))
                     else:
                         dbgv = fix_image(dbgv)
                         os.makedirs(os.path.join(sample_save_path, v), exist_ok=True)
-                        utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (step, i)))
+                        utils.save_image(dbgv.float(), os.path.join(sample_save_path, v, "%05i_%02i.png" % (it, i)))
                 # Some models have their own specific visual debug routines.
                 for net_name, net in self.networks.items():
                     if hasattr(net.module, "visual_dbg"):
                         model_vdbg_dir = os.path.join(sample_save_path, net_name)
                         os.makedirs(model_vdbg_dir, exist_ok=True)
-                        net.module.visual_dbg(step, model_vdbg_dir)
+                        net.module.visual_dbg(it, model_vdbg_dir)
+
+
+    def consume_gradients(self, state, step, it):
+        [e.before_optimize(state) for e in self.experiments]
+        step.do_step(it)
+
+        if step.nan_counter > 10:
+            if self.auto_recover is None:
+                print("Detected NaN grads more than 10 steps in a row. Saving model weights and aborting.")
+                self.save(it)
+                self.save_training_state({'iter': it})
+                raise ArithmeticError
+            else:
+                print(f"!!!!!!!!Detected NaN grads more than 10 steps in a row. Restoring to a state {self.auto_recover} saves ago.")
+                for k, ps in self.save_history.keys():
+                    if len(ps) < self.auto_recover:
+                        print("Belay that - not enough saves were recorded. Failing instead.")
+                        raise ArithmeticError
+                    if k == '__state__':
+                        self.resume_training(torch.load(ps[-self.auto_recover]))
+                    else:
+                        if k in self.networks.keys(): # This isn't always the case, for example for EMAs.
+                            self.load_network(ps[-self.auto_recover], self.networks[k], strict=True)
+                        self.load_network(self.save_history[f'{k}_ema'][-self.auto_recover], self.emas[k], strict=True)
+
+        # Call into custom step hooks as well as update EMA params.
+        for name, net in self.networks.items():
+            if hasattr(net, "custom_optimizer_step"):
+                net.custom_optimizer_step(it)
+            ema_params = self.emas[name].parameters()
+            net_params = net.parameters()
+            for ep, np in zip(ema_params, net_params):
+                if self.ema_on_cpu:
+                    np = np.cpu()
+                ep.detach().mul_(self.ema_rate).add_(np, alpha=1 - self.ema_rate)
+        [e.after_optimize(state) for e in self.experiments]
+
 
     def test(self):
         for net in self.netsG.values():
@@ -416,6 +425,9 @@ class ExtensibleTrainer(BaseModel):
         for o in self.optimizers:
             for pgi, pg in enumerate(o.param_groups):
                 log['learning_rate_%s_%i' % (o._config['network'], pgi)] = pg['lr']
+
+        # The batch size optimizer also outputs loggable data.
+        log.update(self.batch_size_optimizer.get_statistics())
         return log
 
     def get_current_visuals(self, need_GT=True):
codes/trainer/batch_size_optimizer.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+import torch
+
+from utils.util import opt_get
+
+
+def create_batch_size_optimizer(opt_train):
+    if 'batch_size_optimizer' in opt_train.keys():
+        if opt_train['batch_size_optimizer']['type'] == 'gradient_direction':
+            return GradientDirectionOptimizer(opt_train)
+    return MegabatchBatchSizeOptimizer(opt_train)
+
+
+# Base class for BatchSizeOptimizers.
+class BatchSizeOptimizer:
+    def focus(self, optimizer):
+        pass
+
+    def should_step(self, it):
+        raise NotImplementedError
+
+    def get_statistics(self):
+        return {}
+
+
+# BatchSizeOptimizer that just steps every megabatch.
+class MegabatchBatchSizeOptimizer(BatchSizeOptimizer):
+    def __init__(self, opt_train):
+        pass
+
+    def should_step(self, it):
+        return True
+
+
+# BatchSizeOptimizer that uses the gradient direction of a few parameters to determine when to step.
+# Very similar to what is described in https://aclanthology.org/2020.acl-main.323.pdf
+class GradientDirectionOptimizer(BatchSizeOptimizer):
+    def __init__(self, opt_train):
+        self.mbf = opt_train['mega_batch_factor']
+        self.opt = opt_train['batch_size_optimizer']
+        self.max_full_batches = opt_get(self.opt, ['max_full_batches'], 10)
+        self.parameters_to_poll = opt_get(self.opt, ['poll_parameters'], 8)
+        self.recalculate_directions_every = opt_get(self.opt, ['recalculate_directions_steps'], 1)
+        self.last_number_iterations = 0
+
+    def vector_angle(self, v1, v2):
+        with torch.no_grad():
+            v1 = v1.flatten()
+            v2 = v2.flatten()
+            v1_norm = (v1 ** 2).sum().sqrt()
+            v2_norm = (v2 ** 2).sum().sqrt()
+            angle = torch.arccos((v1 * v2) / (v1_norm * v2_norm))
+            return angle
+
+    def focus(self, optimizer):
+        optimizer._gradient_direction_optimizer_params = []
+        optimizer._gradient_direction_optimizer_prior_directions = []
+        optimizer._gradient_direction_optimizer_prior_grads = []
+        optimizer._gradient_direction_optimizer_direction_change_magnitudes = []
+        optimizer._gradient_direction_optimizer_step = 0
+        self.current_opt = optimizer
+
+    def should_step(self, it):
+        self.last_number_iterations += 1
+
+    def get_statistics(self):
+        return {"last_number_iterations_before_step": self.last_number_iterations}
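For reference, a minimal usage sketch of the new module outside ExtensibleTrainer. The Linear model, SGD optimizer, and empty options dict are stand-ins, not part of the commit; with no 'batch_size_optimizer' block, the factory returns MegabatchBatchSizeOptimizer, which steps on every call, while GradientDirectionOptimizer.should_step is still a stub in this commit and returns nothing.

# Hypothetical usage; assumes the repo's codes/ directory is on the import path.
import torch
from trainer.batch_size_optimizer import create_batch_size_optimizer

bso = create_batch_size_optimizer({})        # no 'batch_size_optimizer' key -> MegabatchBatchSizeOptimizer

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
bso.focus(optimizer)                         # no-op for the megabatch variant

for it in range(4):
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()
    if bso.should_step(it):                  # always True for the megabatch variant
        optimizer.step()
        optimizer.zero_grad()

print(bso.get_statistics())                  # {} for the megabatch variant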