Added cfg.bitsandbytes.replace as a less intrusive alternative to cfg.bitsandbytes.injects, replacing all Linear modules in a model

mrq 2024-03-01 19:20:10 -06:00
parent 0427d8d076
commit 47435207f7
5 changed files with 41 additions and 8 deletions
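
The new path boils down to walking named_modules(), looking up each Linear's parent with get_submodule(), and setattr()-ing a replacement in its place, which is what the replace_linear() added at the bottom of this diff does. A minimal, self-contained sketch of that pattern (plain torch only, not part of the commit; QuantLinear is a stand-in of mine for whatever ml.Linear resolves to, not something this repo defines):

import torch
import torch.nn as nn

class QuantLinear(nn.Linear):
    """Stand-in for the quantized Linear that ml.Linear would provide."""
    pass

def swap_linears(model: nn.Module) -> nn.Module:
    device = next(model.parameters()).device
    # collect dotted names first so the module tree isn't mutated while iterating
    names = [name for name, m in model.named_modules() if type(m) is nn.Linear]
    for name in names:
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name)  # get_submodule("") returns the model itself
        old = getattr(parent, attr)
        new = QuantLinear(old.in_features, old.out_features, bias=old.bias is not None).to(device)
        setattr(parent, attr, new)  # Module.__setattr__ registers the replacement as a submodule
    return model

net = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
swap_linears(net)
print(net)                            # both Linear layers are now QuantLinear
print(net(torch.randn(1, 4)).shape)   # torch.Size([1, 2])

As with the replace_linear() in the last hunk, the swapped-in modules start from fresh weights; nothing is copied over from the originals.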

View File

@@ -303,7 +303,7 @@ class Models:
class Hyperparameters:
batch_size: int = 8
gradient_accumulation_steps: int = 32
gradient_clipping: int = 100
gradient_clipping: int | float = 100
optimizer: str = "Adamw"
torch_optimizer: bool = False
@@ -532,6 +532,7 @@ class Inference:
class BitsAndBytes:
enabled: bool = False
injects: bool = False
replace: bool = False
linear: bool = True
embedding: bool = True
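
Enabling the new path would then look roughly like this (a hedged sketch, not part of the commit; the import locations of cfg and the ml wrapper are assumptions on my part):

from vall_e.config import cfg            # assumed location of the global config object
from vall_e.utils import wrapper as ml   # assumed location of the ml helpers

cfg.bitsandbytes.enabled = True
cfg.bitsandbytes.injects = False   # leave the torch classes un-patched (injects reassigns e.g. torch.optim.AdamW, see the last hunk)
cfg.bitsandbytes.replace = True    # swap Linear modules on the already-built model instead

# with the flag set, load_engines() (next file) performs:
#   model.model = ml.replace_linear( model.model )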

View File

@@ -44,6 +44,9 @@ def load_engines(training=True):
if inferencing:
model._cfg.training = False
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
model.model = ml.replace_linear( model.model )
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
optimizer_class = None
params = {

View File

@@ -341,7 +341,7 @@ def example_usage():
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 8,
'n_experts': 1,
}
"""
kwargs = {
@@ -362,10 +362,13 @@ def example_usage():
model = AR_NAR(**kwargs).to(device)
steps = 500
#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer)
if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
model.model = ml.replace_linear( model.model )
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )

View File

@@ -80,8 +80,8 @@ class Attention(nn.Module):
self.n_heads = n_heads
self.scale = dim_head**-0.5
self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False)
self.to_out = nn.Linear(d_model, d_model)
self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
self.to_out = ml.Linear(d_model, d_model)
def forward(self, x, m):
"""
@@ -169,10 +169,10 @@ class Block(nn.Sequential):
n_ff = d_model * 4 # 1024 * 4 = 4096 feed-forwards
self.ffn = PrenormResidual(
nn.Sequential(
nn.Linear(d_model, n_ff),
ml.Linear(d_model, n_ff),
nn.GELU(),
nn.Dropout(p_dropout),
nn.Linear(n_ff, d_model),
ml.Linear(n_ff, d_model),
),
d_model=d_model,
p_dropout=p_dropout,

View File

@@ -83,6 +83,32 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
torch.optim.AdamW = AdamW
torch.optim.SGD = SGD
# disgusting kludge, but it works
def replace_linear( model ):
device = next(model.parameters()).device
linears = [k.split('.') for k, m in model.named_modules() if type(m).__name__ == 'Linear']
for *parent, k in linears:
name = '.'.join(parent)
# fetch the existing module to copy its constructor arguments
m = getattr(
model.get_submodule(name),
k
)
in_features = m.in_features
out_features = m.out_features
bias = False if cfg.bitsandbytes.bitnet else (m.bias is not None) # pass a bool, not the bias tensor itself; BitNet errors out with a bias anyways
# overwrite the module in place
setattr(
model.get_submodule(name),
k,
Linear( in_features=in_features, out_features=out_features, bias=bias ).to(device)
)
return model
# https://github.com/konstmish/prodigy
try:
from prodigyopt import Prodigy