Added a model config option to set the KV head count for MQA/GQA instead of MHA for LLaMA-based models (I think it's fairly negligible either way at such a small model size)
This commit is contained in:
parent
e15c6c74c3
commit
b482ca19ff
|
@ -214,6 +214,7 @@ class Model:
|
||||||
audio_embedding_sums: bool = True
|
audio_embedding_sums: bool = True
|
||||||
dropout: float = 0.1 # adjustable dropout value
|
dropout: float = 0.1 # adjustable dropout value
|
||||||
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
|
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
|
||||||
|
kv_heads: int = 4
|
||||||
|
|
||||||
def get(self, name=None):
|
def get(self, name=None):
|
||||||
return [ self ] if not name or self.name == name else []
|
return [ self ] if not name or self.name == name else []
|
||||||
|
|
|
@ -538,7 +538,7 @@ class Base(nn.Module):
|
||||||
num_hidden_layers=n_layers,
|
num_hidden_layers=n_layers,
|
||||||
num_attention_heads=n_heads,
|
num_attention_heads=n_heads,
|
||||||
attention_dropout=p_dropout if training else 0.0,
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
num_key_value_heads=n_heads,
|
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
is_encoder_decoder=False,
|
is_encoder_decoder=False,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
|
@ -554,7 +554,7 @@ class Base(nn.Module):
|
||||||
num_hidden_layers=n_layers,
|
num_hidden_layers=n_layers,
|
||||||
num_attention_heads=n_heads,
|
num_attention_heads=n_heads,
|
||||||
attention_dropout=p_dropout if training else 0.0,
|
attention_dropout=p_dropout if training else 0.0,
|
||||||
num_key_value_heads=n_heads,
|
num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
|
||||||
sliding_window=75 * 12, # 12 second context window
|
sliding_window=75 * 12, # 12 second context window
|
||||||
output_router_logits=training,
|
output_router_logits=training,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user