From 467fa1c5ee8e8af4f660d0103add52f6dc71971c Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 16 Apr 2024 10:19:02 -0500
Subject: [PATCH] wrapper fixes

---
 vall_e/models/base.py   | 8 ++++----
 vall_e/utils/wrapper.py | 8 ++++++--
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index e2cfa2b..3255c0a 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -395,7 +395,7 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -414,7 +414,7 @@
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
@@ -431,7 +431,7 @@
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -450,7 +450,7 @@
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))

 		elif self.arch_type == "retnet":
diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 1a8e122..1741a8e 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -101,6 +101,8 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:

 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
+	bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+
 	device = next(model.parameters()).device
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
 	for *parent, k in linears:
@@ -113,13 +115,15 @@
 		out_features = m.out_features
 		bias = m.bias is not None

+		kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
+
 		# overwrite
 		setattr(
 			model.get_submodule(name), k,
-			Linear( in_features=in_features, out_features=out_features, bias=bias )
+			Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)

-	return model.to(device) # because our now Linear is created on the CPU......
+	return model

 # https://github.com/konstmish/prodigy
 try:
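
Notes on the two fixes:

- base.py: when no config is present, attn_implementation now falls back to
  None instead of "flash_attention_2", which lets transformers pick a backend
  on its own (sdpa where available, else eager) rather than hard-requiring
  flash-attn to be installed.
- wrapper.py: replace_linear now constructs the replacement Linear directly on
  the source module's device and in cfg.trainer.dtype, so the trailing
  model.to(device) kludge goes away. The kwargs split exists because the
  quantized bitsandbytes Linear names its constructor arguments
  input_features/output_features, while torch.nn.Linear uses
  in_features/out_features.

A minimal sketch of that constructor selection in isolation; make_linear is a
hypothetical helper for illustration, not part of the patch, and the
bitsandbytes argument names assume its Linear8bitLt signature:

	import torch

	# hypothetical helper mirroring the kwargs split in replace_linear above
	def make_linear( in_features, out_features, bias, bnb, device, dtype ):
		if bnb:
			import bitsandbytes
			# the 8-bit Linear spells its argument names differently
			Linear = bitsandbytes.nn.Linear8bitLt
			kwargs = dict(input_features=in_features, output_features=out_features, bias=bias)
		else:
			Linear = torch.nn.Linear
			kwargs = dict(in_features=in_features, out_features=out_features, bias=bias)
		# create, then move to the target device/dtype in one step, so the
		# replacement module never lingers on the CPU
		return Linear( **kwargs ).to(device=device, dtype=dtype)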