wrapper fixes

mrq 2024-04-16 10:19:02 -05:00
parent aa1e25fbf5
commit 467fa1c5ee
2 changed files with 10 additions and 6 deletions


@@ -395,7 +395,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
-attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
))
else:
self.model = MixtralModel(MixtralConfig(
@@ -414,7 +414,7 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
-attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
))
elif self.arch_type == "llama":
if n_experts <= 1:
@@ -431,7 +431,7 @@ class Base(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
-attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
))
else:
self.model = MixtralModel(MixtralConfig(
@@ -450,7 +450,7 @@ class Base(nn.Module):
is_decoder=True,
num_local_experts=n_experts,
num_experts_per_tok=min(2, n_experts),
-attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
))
elif self.arch_type == "retnet":
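
The net change in this first file: when the wrapper has no explicit attention setting, the configs now fall back to None instead of hard-coding "flash_attention_2", so transformers can pick its own attention backend (typically SDPA or eager) rather than requiring the flash-attn package to be installed. A minimal sketch of the same fallback pattern, using a hypothetical WrapperConfig object in place of the project's actual config:

from transformers import LlamaConfig, LlamaModel

# hypothetical stand-in for the wrapper's config; only .attention matters here
class WrapperConfig:
    attention = None  # e.g. "flash_attention_2", "sdpa", "eager", or None

wrapper_cfg = WrapperConfig()

model = LlamaModel(LlamaConfig(
    vocab_size=256,
    hidden_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    # None defers to transformers' default backend instead of
    # forcing flash-attn to be present
    attn_implementation=wrapper_cfg.attention if wrapper_cfg is not None else None,
))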


@@ -101,6 +101,8 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
def replace_linear( model ):
bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+device = next(model.parameters()).device
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
for *parent, k in linears:
@@ -113,13 +115,15 @@ def replace_linear( model ):
out_features = m.out_features
bias = m.bias is not None
+kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
# overwrite
setattr(
model.get_submodule(name), k,
-Linear( in_features=in_features, out_features=out_features, bias=bias )
+Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
)
-return model.to(device) # because our now Linear is created on the CPU......
+return model
# https://github.com/konstmish/prodigy
try:
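
Taken together, the replace_linear change above captures the model's device once, builds the replacement layer's keyword arguments to match the target class (bitsandbytes-style linears take input_features/output_features rather than in_features/out_features), and creates each new layer directly on that device and dtype, so the function no longer has to re-.to(device) the whole model on return. A self-contained sketch of that shape, with the Linear class, dtype, and bnb flag passed in explicitly instead of being read from the project's cfg:

import torch

def replace_linear(model, Linear=torch.nn.Linear, dtype=torch.float32, bnb=False):
    # the device the model already lives on; replacements are created there too
    device = next(model.parameters()).device

    # dotted paths of every plain nn.Linear in the model
    linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
    for *parent, k in linears:
        name = '.'.join(parent)
        m = getattr(model.get_submodule(name), k)

        # bitsandbytes linear classes spell their size arguments differently,
        # so pick the keyword names to match the class being swapped in
        kwargs = (
            dict(in_features=m.in_features, out_features=m.out_features, bias=m.bias is not None)
            if not bnb else
            dict(input_features=m.in_features, output_features=m.out_features, bias=m.bias is not None)
        )

        # build the replacement on the original device/dtype up front instead of
        # leaving it on the CPU and moving the whole model afterwards
        setattr(model.get_submodule(name), k, Linear(**kwargs).to(device=device, dtype=dtype))

    return model

Called as, for example, replace_linear(model, Linear=SomeDropInLinear, dtype=torch.float16) for a plain swap (SomeDropInLinear being a placeholder), or with a bitsandbytes linear class and bnb=True when quantized linears are wanted.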