cleaner replacement code (because I realized BitNet had an implementation for it too), added calculating gradient norm and performing gradient clipping in local trainer (non-deepspeed)

This commit is contained in:
mrq 2024-03-01 20:18:43 -06:00
parent 47435207f7
commit f3c59c3e7e
3 changed files with 22 additions and 14 deletions

View File

@ -103,6 +103,10 @@ class Engine():
def gradient_accumulation_steps(self): def gradient_accumulation_steps(self):
return cfg.hyperparameters.gradient_accumulation_steps return cfg.hyperparameters.gradient_accumulation_steps
@property
def gradient_clipping(self):
return cfg.hyperparameters.gradient_clipping
def gather_attribute(self, *args, **kwargs): def gather_attribute(self, *args, **kwargs):
return gather_attribute(self.module, *args, **kwargs) return gather_attribute(self.module, *args, **kwargs)
@ -186,10 +190,17 @@ class Engine():
self.global_samples += self.batch_size self.global_samples += self.batch_size
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0: if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
self.global_steps += 1 self.global_steps += 1
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self._get_grad_norm()
def _get_grad_norm(self):
self._global_grad_norm = torch.cat([ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]).norm().item()
def get_lr(self): def get_lr(self):
lrs = [] lrs = []
for param_group in self.optimizer.param_groups: for param_group in self.optimizer.param_groups:
@ -207,7 +218,7 @@ class Engine():
param_group['lr'] = lr param_group['lr'] = lr
def get_global_grad_norm(self): def get_global_grad_norm(self):
return 0.0 return self._global_grad_norm
def traverse(self, *args, **kwargs): def traverse(self, *args, **kwargs):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp): with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):

View File

@ -336,6 +336,7 @@ def example_usage():
proms_list = proms_list[:1] proms_list = proms_list[:1]
resps_list = resps_list[:1] resps_list = resps_list[:1]
"""
kwargs = { kwargs = {
'n_tokens': 1024, 'n_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536 'd_model': 1024, # 256, # 1024, # 1536
@ -351,7 +352,6 @@ def example_usage():
'n_layers': 12, 'n_layers': 12,
'n_experts': 8, 'n_experts': 8,
} }
"""
""" """
try: try:
@ -397,6 +397,7 @@ def example_usage():
for i in t: for i in t:
stats = {"step": i} stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list) stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}") tqdm.write(f"{stats}")

View File

@ -83,31 +83,27 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
torch.optim.AdamW = AdamW torch.optim.AdamW = AdamW
torch.optim.SGD = SGD torch.optim.SGD = SGD
# disgusting kludge, but it works # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
def replace_linear( model ): def replace_linear( model ):
device = next(model.parameters()).device device = next(model.parameters()).device
linears = [k.split('.') for k, m in model.named_modules() if type(m).__name__ == 'Linear'] linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
for *parent, k in linears: for *parent, k in linears:
name = '.'.join(parent) name = '.'.join(parent)
# copy parameters # copy parameters
m = getattr( m = getattr( model.get_submodule(name), k )
model.get_submodule(name),
k
)
in_features = m.in_features in_features = m.in_features
out_features = m.out_features out_features = m.out_features
bias = False if cfg.bitsandbytes.bitnet else m.bias # errors out with BitNet bias = m.bias is not None
# overwright # overwrite
setattr( setattr(
model.get_submodule(name), model.get_submodule(name), k,
k, Linear( in_features=in_features, out_features=out_features, bias=bias )
Linear( in_features=in_features, out_features=out_features, bias=bias ).to(device)
) )
return model return model.to(device) # because our now Linear is created on the CPU......
# https://github.com/konstmish/prodigy # https://github.com/konstmish/prodigy
try: try: