Cleaner replacement code (because I realized BitNet has its own implementation for it too); added gradient-norm calculation and gradient clipping to the local (non-DeepSpeed) trainer
This commit is contained in:
parent
47435207f7
commit
f3c59c3e7e
|
@ -103,6 +103,10 @@ class Engine():
|
||||||
def gradient_accumulation_steps(self):
|
def gradient_accumulation_steps(self):
|
||||||
return cfg.hyperparameters.gradient_accumulation_steps
|
return cfg.hyperparameters.gradient_accumulation_steps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gradient_clipping(self):
|
||||||
|
return cfg.hyperparameters.gradient_clipping
|
||||||
|
|
||||||
def gather_attribute(self, *args, **kwargs):
|
def gather_attribute(self, *args, **kwargs):
|
||||||
return gather_attribute(self.module, *args, **kwargs)
|
return gather_attribute(self.module, *args, **kwargs)
|
||||||
|
|
||||||
|
@ -186,10 +190,17 @@ class Engine():
|
||||||
self.global_samples += self.batch_size
|
self.global_samples += self.batch_size
|
||||||
|
|
||||||
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
if (self.micro_steps + 1) % max(1, self.gradient_accumulation_steps) == 0:
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.module.parameters(), self.gradient_clipping)
|
||||||
|
|
||||||
self.global_steps += 1
|
self.global_steps += 1
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
self._get_grad_norm()
|
||||||
|
|
||||||
|
def _get_grad_norm(self):
|
||||||
|
self._global_grad_norm = torch.cat([ param.grad.detach().flatten() for param in self.module.parameters() if param.grad is not None ]).norm().item()
|
||||||
|
|
||||||
def get_lr(self):
|
def get_lr(self):
|
||||||
lrs = []
|
lrs = []
|
||||||
for param_group in self.optimizer.param_groups:
|
for param_group in self.optimizer.param_groups:
|
||||||
|
@ -207,7 +218,7 @@ class Engine():
|
||||||
param_group['lr'] = lr
|
param_group['lr'] = lr
|
||||||
|
|
||||||
def get_global_grad_norm(self):
|
def get_global_grad_norm(self):
|
||||||
return 0.0
|
return self._global_grad_norm
|
||||||
|
|
||||||
def traverse(self, *args, **kwargs):
|
def traverse(self, *args, **kwargs):
|
||||||
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
|
||||||
|
|
|
@ -336,6 +336,7 @@ def example_usage():
|
||||||
proms_list = proms_list[:1]
|
proms_list = proms_list[:1]
|
||||||
resps_list = resps_list[:1]
|
resps_list = resps_list[:1]
|
||||||
|
|
||||||
|
"""
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'n_tokens': 1024,
|
'n_tokens': 1024,
|
||||||
'd_model': 1024, # 256, # 1024, # 1536
|
'd_model': 1024, # 256, # 1024, # 1536
|
||||||
|
@ -351,7 +352,6 @@ def example_usage():
|
||||||
'n_layers': 12,
|
'n_layers': 12,
|
||||||
'n_experts': 8,
|
'n_experts': 8,
|
||||||
}
|
}
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
@ -397,6 +397,7 @@ def example_usage():
|
||||||
for i in t:
|
for i in t:
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||||
|
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||||
|
|
||||||
tqdm.write(f"{stats}")
|
tqdm.write(f"{stats}")
|
||||||
|
|
||||||
|
|
|
@ -83,31 +83,27 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
||||||
torch.optim.AdamW = AdamW
|
torch.optim.AdamW = AdamW
|
||||||
torch.optim.SGD = SGD
|
torch.optim.SGD = SGD
|
||||||
|
|
||||||
# disgusting kludge, but it works
|
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||||
def replace_linear( model ):
|
def replace_linear( model ):
|
||||||
device = next(model.parameters()).device
|
device = next(model.parameters()).device
|
||||||
linears = [k.split('.') for k, m in model.named_modules() if type(m).__name__ == 'Linear']
|
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
|
||||||
for *parent, k in linears:
|
for *parent, k in linears:
|
||||||
name = '.'.join(parent)
|
name = '.'.join(parent)
|
||||||
|
|
||||||
# copy parameters
|
# copy parameters
|
||||||
m = getattr(
|
m = getattr( model.get_submodule(name), k )
|
||||||
model.get_submodule(name),
|
|
||||||
k
|
|
||||||
)
|
|
||||||
|
|
||||||
in_features = m.in_features
|
in_features = m.in_features
|
||||||
out_features = m.out_features
|
out_features = m.out_features
|
||||||
bias = False if cfg.bitsandbytes.bitnet else m.bias # errors out with BitNet
|
bias = m.bias is not None
|
||||||
|
|
||||||
# overwright
|
# overwrite
|
||||||
setattr(
|
setattr(
|
||||||
model.get_submodule(name),
|
model.get_submodule(name), k,
|
||||||
k,
|
Linear( in_features=in_features, out_features=out_features, bias=bias )
|
||||||
Linear( in_features=in_features, out_features=out_features, bias=bias ).to(device)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return model
|
return model.to(device) # because our new Linear is created on the CPU
|
||||||
|
|
||||||
# https://github.com/konstmish/prodigy
|
# https://github.com/konstmish/prodigy
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user