From f3c59c3e7eb165a1c9b146271c053fe99e0429f7 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 1 Mar 2024 20:18:43 -0600
Subject: [PATCH] cleaner replacement code (because I realized BitNet had an
 implementation for it too), added calculating gradient norm and performing
 gradient clipping in local trainer (non-deepspeed)

---
 vall_e/engines/base.py  | 13 ++++++++++++-
 vall_e/models/ar_nar.py |  3 ++-
 vall_e/utils/wrapper.py | 20 ++++++++------------
 3 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 70a97ec..a0bb626 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -102,6 +102,10 @@ class Engine():
 	@property
 	def gradient_accumulation_steps(self):
 		return cfg.hyperparameters.gradient_accumulation_steps
+
+	@property
+	def gradient_clipping(self):
+		return cfg.hyperparameters.gradient_clipping
 
 	def gather_attribute(self, *args, **kwargs):
 		return gather_attribute(self.module, *args, **kwargs)
@@ -186,10 +190,17 @@ class Engine():
 		self.global_samples += self.batch_size
 
 		if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
+			torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
+
 			self.global_steps += 1
 			self.optimizer.step()
 			self.optimizer.zero_grad()
 
+			self._get_grad_norm()
+
+	def _get_grad_norm(self):
+		self._global_grad_norm = torch.cat([ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]).norm().item()
+
 	def get_lr(self):
 		lrs = []
 		for param_group in self.optimizer.param_groups:
@@ -207,7 +218,7 @@ class Engine():
 			param_group['lr'] = lr
 
 	def get_global_grad_norm(self):
-		return 0.0
+		return self._global_grad_norm
 
 	def traverse(self, *args, **kwargs):
 		with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ef4a229..dc3b202 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -336,6 +336,7 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 
+	"""
 	kwargs = {
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
@@ -351,7 +352,6 @@ def example_usage():
 		'n_layers': 12,
 		'n_experts': 8,
 	}
-	"""
 
 	"""
 	try:
@@ -397,6 +397,7 @@ def example_usage():
 	for i in t:
 		stats = {"step": i}
 		stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
+		stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 		tqdm.write(f"{stats}")
 
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index d457b25..6ca48e2 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -83,31 +83,27 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
 
-# disgusting kludge, but it works
+# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
 	device = next(model.parameters()).device
-	linears = [k.split('.') for k, m in model.named_modules() if type(m).__name__ == 'Linear']
+	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
 	for *parent, k in linears:
 		name = '.'.join(parent)
 
 		# copy parameters
-		m = getattr(
-			model.get_submodule(name),
-			k
-		)
+		m = getattr( model.get_submodule(name), k )
 
 		in_features = m.in_features
 		out_features = m.out_features
-		bias = False if cfg.bitsandbytes.bitnet else m.bias # errors out with BitNet
+		bias = m.bias is not None
 
-		# overwright
+		# overwrite
 		setattr(
-			model.get_submodule(name),
-			k,
-			Linear( in_features=in_features, out_features=out_features, bias=bias ).to(device)
+			model.get_submodule(name), k,
+			Linear( in_features=in_features, out_features=out_features, bias=bias )
 		)
 
-	return model
+	return model.to(device) # because our now Linear is created on the CPU......
 
 # https://github.com/konstmish/prodigy
 try:
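
The gist of the trainer change above is: accumulate gradients over several micro-steps, clip the global gradient norm, step the optimizer, and report the measured norm. Below is a minimal, self-contained sketch of that accumulate/clip/report pattern, not code from this repository; train_step, loss_fn, batches, accum_steps, and max_norm are illustrative placeholders. It leans on the fact that torch.nn.utils.clip_grad_norm_ returns the total pre-clip norm, so that return value can double as the reported metric.

# Illustrative sketch only (not part of the patch): gradient accumulation with
# clipping and grad-norm reporting in a plain local training loop.
import torch

def train_step(model, optimizer, loss_fn, batches, accum_steps=4, max_norm=1.0):
	model.train()
	grad_norm = None
	for i, (x, y) in enumerate(batches):
		loss = loss_fn(model(x), y) / accum_steps  # scale so accumulated grads average out
		loss.backward()                            # gradients accumulate across micro-steps
		if (i + 1) % accum_steps == 0:
			# clip_grad_norm_ clips in place and returns the pre-clip total norm,
			# so it can double as the "global grad norm" metric
			grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
			optimizer.step()
			optimizer.zero_grad()
	return grad_norm

Taking the norm from clip_grad_norm_'s return value also avoids a second pass over the parameters, and measures it before the gradients are cleared.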