master
mrq 2024-03-01 20:38:06 +07:00
parent f3c59c3e7e
commit 91062361af
2 changed files with 4 additions and 2 deletions

@@ -199,7 +199,8 @@ class Engine():
 		self._get_grad_norm()
 	def _get_grad_norm(self):
-		self._global_grad_norm = torch.cat([ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]).norm().item()
+		t = [ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]
+		self._global_grad_norm = torch.cat(t).norm().item() if len(t) else 0
 	def get_lr(self):
 		lrs = []

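For context: the first hunk fixes a crash in the gradient-norm bookkeeping. When no parameter has a .grad yet (e.g. before the first backward pass, or with everything frozen), the list comprehension is empty and torch.cat() raises a RuntimeError; the patched version falls back to 0. A minimal standalone sketch of the same guard (the helper name global_grad_norm is mine, not the repo's):

    import torch
    import torch.nn as nn

    def global_grad_norm(module: nn.Module) -> float:
        # Flattened gradients, skipping parameters that have none
        # (frozen weights, or backward() has not run yet).
        grads = [ p.grad.detach().flatten() for p in module.parameters() if p.grad is not None ]
        # torch.cat([]) raises RuntimeError, hence the guard -- same fix as the commit.
        return torch.cat(grads).norm().item() if len(grads) else 0.0
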
@@ -336,7 +336,7 @@ def example_usage():
 	proms_list = proms_list[:1]
 	resps_list = resps_list[:1]
 	"""
+	# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
 		'n_tokens': 1024,
 		'd_model': 1024, # 256, # 1024, # 1536
@@ -352,6 +352,7 @@ def example_usage():
 		'n_layers': 12,
 		'n_experts': 8,
 	}
 	"""
 	"""
 	try:
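The retnet-full comment in the second hunk ties back to the same grad-norm bookkeeping: BitNet's BitLinear quantizes weights in the forward pass but routes gradients to the latent full-precision weights through a straight-through estimator, so the norm measured above reflects the latent weights and can look divergent while the quantized model is in fact converging. A rough illustrative sketch of that mechanism, assuming the usual BitNet-style binarization (this is not the repo's actual BitLinear implementation):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BitLinearSketch(nn.Linear):
        # Illustrative BitNet-style layer: weights binarized to +/-1 around
        # their mean, rescaled by the mean absolute weight.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            w = self.weight
            scale = w.abs().mean()
            w_bin = torch.sign(w - w.mean()) * scale
            # Straight-through estimator: forward uses the binarized weights,
            # backward sends gradients to the latent full-precision weights --
            # which is what _get_grad_norm() above ends up measuring.
            w_ste = w + (w_bin - w).detach()
            return F.linear(x, w_ste, self.bias)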