diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index a0bb626..7793868 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -199,7 +199,8 @@ class Engine():
 		self._get_grad_norm()
 
 	def _get_grad_norm(self):
-		self._global_grad_norm = torch.cat([ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]).norm().item()
+		t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
+		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else 0
 
 	def get_lr(self):
 		lrs = []
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index dc3b202..0984330 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -336,7 +336,7 @@ def example_usage():
 		proms_list = proms_list[:1]
 		resps_list = resps_list[:1]
 
-	"""
+	# retnet-full is the only configuration with BitNet's BitLinear that converges, despite the grad_norm saying otherwise
 	kwargs = {
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
@@ -352,6 +352,7 @@ def example_usage():
 		'n_layers': 12,
 		'n_experts': 8,
 	}
+	"""
 
 	"""
 	try:
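
Note on the Engine._get_grad_norm change: before the first backward pass (or when every parameter is frozen) no parameter carries a .grad, so the list comprehension yields an empty list and the old torch.cat(...) one-liner raised a RuntimeError; the guarded version falls back to 0 instead. A minimal standalone sketch of the same pattern follows -- the helper name global_grad_norm and the toy torch.nn.Linear model are illustrative, not part of the repo:

    import torch

    def global_grad_norm(module: torch.nn.Module) -> float:
        # Mirrors the patched Engine._get_grad_norm: gather every available
        # gradient, flatten them into one vector, and take its L2 norm.
        t = [ p.grad.detach().flatten() for p in module.parameters() if p.grad is not None ]
        # torch.cat([]) raises RuntimeError, so guard against an empty list.
        return torch.cat(t).norm().item() if len(t) else 0

    model = torch.nn.Linear(4, 2)
    print(global_grad_norm(model))   # 0 -- no gradients populated yet

    model(torch.randn(3, 4)).sum().backward()
    print(global_grad_norm(model))   # > 0 once gradients exist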