From 47435207f7ef065268f9e1e560d3ab1e2783c0f8 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 1 Mar 2024 19:20:10 -0600
Subject: [PATCH] Added cfg.bitsandbytes.replace as a less intrusive
 alternative to cfg.bitsandbytes.inject to replace all Linear modules in a
 model

---
 vall_e/config.py             |  3 ++-
 vall_e/engines/__init__.py   |  3 +++
 vall_e/models/ar_nar.py      |  9 ++++++---
 vall_e/models/transformer.py |  8 ++++----
 vall_e/utils/wrapper.py      | 26 ++++++++++++++++++++++++++
 5 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index e3a733f..e4df91c 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -303,7 +303,7 @@ class Models:
 class Hyperparameters:
 	batch_size: int = 8
 	gradient_accumulation_steps: int = 32
-	gradient_clipping: int = 100
+	gradient_clipping: int | float = 100
 
 	optimizer: str = "Adamw"
 	torch_optimizer: bool = False
@@ -532,6 +532,7 @@ class Inference:
 class BitsAndBytes:
 	enabled: bool = False
 	injects: bool = False
+	replace: bool = False
 
 	linear: bool = True
 	embedding: bool = True
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 940acf3..0a0eb05 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -44,6 +44,9 @@ def load_engines(training=True):
 	if inferencing:
 		model._cfg.training = False
 
+	if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+		model.model = ml.replace_linear( model.model )
+
 	if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
 		optimizer_class = None
 		params = {
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 3b5d875..ef4a229 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -341,7 +341,7 @@ def example_usage():
 		'd_model': 1024, # 256, # 1024, # 1536
 		'n_heads': 16, # 4, # 16, # 24
 		'n_layers': 12, # 32
-		'n_experts': 8,
+		'n_experts': 1,
 	}
 	"""
 	kwargs = {
@@ -362,10 +362,13 @@ def example_usage():
 	model = AR_NAR(**kwargs).to(device)
 	steps = 500
 
-	#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-	optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
+	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
 	engine = Engine(model=model, optimizer=optimizer)
 
+	if cfg.bitsandbytes.enabled and cfg.bitsandbytes.replace:
+		model.model = ml.replace_linear( model.model )
+
 	torch.save( {
 		'module': model.state_dict()
 	}, "./data/test.pth" )
diff --git a/vall_e/models/transformer.py b/vall_e/models/transformer.py
index 2147839..4d749c1 100755
--- a/vall_e/models/transformer.py
+++ b/vall_e/models/transformer.py
@@ -80,8 +80,8 @@ class Attention(nn.Module):
 		self.n_heads = n_heads
 		self.scale = dim_head**-0.5
 
-		self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False)
-		self.to_out = nn.Linear(d_model, d_model)
+		self.to_qkv = ml.Linear(d_model, d_model * 3, bias=False)
+		self.to_out = ml.Linear(d_model, d_model)
 
 	def forward(self, x, m):
 		"""
@@ -169,10 +169,10 @@ class Block(nn.Sequential):
 		n_ff = d_model * 4 # 1024 * 4 = 4096 feed-forwards
 		self.ffn = PrenormResidual(
 			nn.Sequential(
-				nn.Linear(d_model, n_ff),
+				ml.Linear(d_model, n_ff),
 				nn.GELU(),
 				nn.Dropout(p_dropout),
-				nn.Linear(n_ff, d_model),
+				ml.Linear(n_ff, d_model),
 			),
 			d_model=d_model,
 			p_dropout=p_dropout,
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index b00399d..d457b25 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -83,6 +83,32 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 	torch.optim.AdamW = AdamW
 	torch.optim.SGD = SGD
 
+# disgusting kludge, but it works
+def replace_linear( model ):
+	device = next(model.parameters()).device
+	linears = [k.split('.') for k, m in model.named_modules() if type(m).__name__ == 'Linear']
+	for *parent, k in linears:
+		name = '.'.join(parent)
+
+		# grab the existing module to read its dimensions
+		m = getattr(
+			model.get_submodule(name),
+			k
+		)
+
+		in_features = m.in_features
+		out_features = m.out_features
+		bias = False if cfg.bitsandbytes.bitnet else m.bias is not None # errors out with BitNet
+
+		# overwrite
+		setattr(
+			model.get_submodule(name),
+			k,
+			Linear( in_features=in_features, out_features=out_features, bias=bias ).to(device)
+		)
+
+	return model
+
 # https://github.com/konstmish/prodigy
 try:
 	from prodigyopt import Prodigy
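
For review, the replacement technique boils down to two steps: collect the dotted path of every Linear from named_modules() first, then swap each one out on its parent module with setattr. Below is a minimal, self-contained sketch of that pattern, separate from the patch itself. QuantLinear is a hypothetical stand-in for whatever quantized Linear wrapper.py resolves to (bitsandbytes, BitNet, etc.), and, like replace_linear() in the patch, it re-initializes weights rather than copying them, so it should run before a checkpoint is loaded or training begins.

import torch
from torch import nn

class QuantLinear(nn.Linear):
	# hypothetical stand-in for the quantized Linear that wrapper.py exports
	pass

def replace_linear(model: nn.Module) -> nn.Module:
	device = next(model.parameters()).device
	# collect dotted paths up front so the module tree isn't mutated mid-traversal
	paths = [
		name.split('.')
		for name, module in model.named_modules()
		if isinstance(module, nn.Linear) and not isinstance(module, QuantLinear)
	]
	for *parent, attr in paths:
		owner = model.get_submodule('.'.join(parent))  # '' resolves to model itself
		old = getattr(owner, attr)
		# same-shape replacement; weights are re-initialized, not copied
		new = QuantLinear(old.in_features, old.out_features, bias=old.bias is not None)
		setattr(owner, attr, new.to(device))
	return model

model = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
model = replace_linear(model)
print(model)  # both nn.Linear layers now report as QuantLinear

Gathering the paths into a list before the loop mirrors the patch's list comprehension and avoids replacing modules while named_modules() is still iterating over them.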