wrapper fixes
This commit is contained in:
parent
aa1e25fbf5
commit
467fa1c5ee
|
@ -395,7 +395,7 @@ class Base(nn.Module):
|
|||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
|
||||
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
|
||||
))
|
||||
else:
|
||||
self.model = MixtralModel(MixtralConfig(
|
||||
|
@ -414,7 +414,7 @@ class Base(nn.Module):
|
|||
is_decoder=True,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
|
||||
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
|
||||
))
|
||||
elif self.arch_type == "llama":
|
||||
if n_experts <= 1:
|
||||
|
@ -431,7 +431,7 @@ class Base(nn.Module):
|
|||
hidden_act="gelu",
|
||||
is_encoder_decoder=False,
|
||||
is_decoder=True,
|
||||
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
|
||||
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
|
||||
))
|
||||
else:
|
||||
self.model = MixtralModel(MixtralConfig(
|
||||
|
@ -450,7 +450,7 @@ class Base(nn.Module):
|
|||
is_decoder=True,
|
||||
num_local_experts=n_experts,
|
||||
num_experts_per_tok=min(2, n_experts),
|
||||
attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
|
||||
attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
|
||||
))
|
||||
|
||||
elif self.arch_type == "retnet":
|
||||
|
|
|
@ -101,6 +101,8 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
|
|||
|
||||
# disgusting kludge, but it works (just realized BitNet has its own replacement routine)
|
||||
def replace_linear( model ):
|
||||
bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
|
||||
|
||||
device = next(model.parameters()).device
|
||||
linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
|
||||
for *parent, k in linears:
|
||||
|
@ -113,13 +115,15 @@ def replace_linear( model ):
|
|||
out_features = m.out_features
|
||||
bias = m.bias is not None
|
||||
|
||||
kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
|
||||
|
||||
# overwrite
|
||||
setattr(
|
||||
model.get_submodule(name), k,
|
||||
Linear( in_features=in_features, out_features=out_features, bias=bias )
|
||||
Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
|
||||
)
|
||||
|
||||
return model.to(device) # because our now Linear is created on the CPU......
|
||||
return model
|
||||
|
||||
# https://github.com/konstmish/prodigy
|
||||
try:
|
||||
|
|
Loading…
Reference in New Issue
Block a user