wrapper fixes
commit 467fa1c5ee
parent aa1e25fbf5
@@ -395,7 +395,7 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -414,7 +414,7 @@ class Base(nn.Module):
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 		elif self.arch_type == "llama":
 			if n_experts <= 1:
@@ -431,7 +431,7 @@ class Base(nn.Module):
 					hidden_act="gelu",
 					is_encoder_decoder=False,
 					is_decoder=True,
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 			else:
 				self.model = MixtralModel(MixtralConfig(
@@ -450,7 +450,7 @@ class Base(nn.Module):
 					is_decoder=True,
 					num_local_experts=n_experts,
 					num_experts_per_tok=min(2, n_experts),
-					attn_implementation=self.config.attention if self.config is not None else "flash_attention_2", # None
+					attn_implementation=self.config.attention if self.config is not None else None, # "flash_attention_2",
 				))
 
 		elif self.arch_type == "retnet":
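The four hunks above make the same change in every architecture branch: when no project config is present, attn_implementation now defaults to None (letting transformers fall back to its own default attention backend) instead of forcing "flash_attention_2". A minimal sketch of the resulting behaviour, assuming the stock transformers LlamaConfig/LlamaModel API; DummyConfig below is a hypothetical stand-in for the project's own self.config object, not part of the commit:

# Sketch only: illustrates the changed default, not the project's actual wrapper.
from dataclasses import dataclass
from transformers import LlamaConfig, LlamaModel

@dataclass
class DummyConfig:
	attention: str = "sdpa"  # e.g. "eager", "sdpa", "flash_attention_2"

def build_model(config=None):
	# Old default: `else "flash_attention_2"` forced flash-attn even when the
	# package wasn't installed. New default: `else None` lets transformers
	# pick its own attention backend.
	attn = config.attention if config is not None else None
	return LlamaModel(LlamaConfig(
		hidden_size=256,
		intermediate_size=512,
		num_hidden_layers=2,
		num_attention_heads=4,
		is_decoder=True,
		attn_implementation=attn,
	))

model = build_model()               # transformers' default backend
model = build_model(DummyConfig())  # backend named in the config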
@@ -101,6 +101,8 @@ if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
 
 # disgusting kludge, but it works (just realized BitNet has its own replacement routine)
 def replace_linear( model ):
+	bnb = cfg.bitsandbytes.enabled and cfg.bitsandbytes.linear and not cfg.bitsandbytes.bitnet
+
 	device = next(model.parameters()).device
 	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
 	for *parent, k in linears:
@@ -113,13 +115,15 @@ def replace_linear( model ):
 		out_features = m.out_features
 		bias = m.bias is not None
 
+		kwargs = dict(in_features=in_features, out_features=out_features, bias=bias) if not bnb else dict(input_features=in_features, output_features=out_features, bias=bias)
+
 		# overwrite
 		setattr(
 			model.get_submodule(name), k,
-			Linear( in_features=in_features, out_features=out_features, bias=bias )
+			Linear( **kwargs ).to(device=device, dtype=cfg.trainer.dtype)
 		)
 
-	return model.to(device) # because our now Linear is created on the CPU......
+	return model
 
 # https://github.com/konstmish/prodigy
 try:
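Taken together, the last two hunks make replace_linear pass the keyword names the replacement layer expects (the kwargs switch suggests bitsandbytes' Linear8bitLt, whose constructor takes input_features/output_features rather than in_features/out_features) and build it directly on the source model's device and dtype, so the trailing model.to(device) is no longer needed. A rough, self-contained sketch of the resulting routine; the use_bnb and dtype arguments are illustrative stand-ins for the project's cfg flags, and the Linear8bitLt signature is an assumption:

# Sketch only: approximates the post-commit replace_linear, not the exact project code.
import torch
import bitsandbytes as bnb

def replace_linear(model, use_bnb=True, dtype=torch.float16):
	# bitsandbytes' Linear8bitLt names its size arguments input_features /
	# output_features, hence the kwargs switch below.
	Linear = bnb.nn.Linear8bitLt if use_bnb else torch.nn.Linear

	device = next(model.parameters()).device
	linears = [k.split('.') for k, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
	for *parent, k in linears:
		name = '.'.join(parent)
		m = getattr(model.get_submodule(name), k)
		bias = m.bias is not None

		kwargs = (
			dict(input_features=m.in_features, output_features=m.out_features, bias=bias)
			if use_bnb else
			dict(in_features=m.in_features, out_features=m.out_features, bias=bias)
		)

		# Build the replacement directly on the right device/dtype, so the
		# caller no longer has to move the whole model afterwards.
		setattr(model.get_submodule(name), k, Linear(**kwargs).to(device=device, dtype=dtype))

	return model

In the commit itself the bnb flag, the replacement Linear class, and the dtype come from the project config (cfg.bitsandbytes.*, cfg.trainer.dtype) rather than from function arguments.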