From b482ca19ff9540717fb9a7ebbd4ea41a7f5ab132 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 31 May 2024 19:32:37 -0500
Subject: [PATCH] added model config option to set KV head count for MQA/GQA
 instead of MHA for llama-based models (i think its very negligible both ways
 on such a small model size)

---
 vall_e/config.py      | 1 +
 vall_e/models/base.py | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 90bf6b8..8df7b57 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -214,6 +214,7 @@ class Model:
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
 	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
+	kv_heads: int = 4
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b830eb5..f0a019e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -538,7 +538,7 @@ class Base(nn.Module):
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout if training else 0.0,
-				num_key_value_heads=n_heads,
+				num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
 				hidden_act="gelu",
 				is_encoder_decoder=False,
 				is_decoder=True,
@@ -554,7 +554,7 @@ class Base(nn.Module):
 				num_hidden_layers=n_layers,
 				num_attention_heads=n_heads,
 				attention_dropout=p_dropout if training else 0.0,
-				num_key_value_heads=n_heads,
+				num_key_value_heads=self.config.kv_heads if self.config.kv_heads > 0 else n_heads,
 				sliding_window=75 * 12, # 12 second context window
 				output_router_logits=training,
 				hidden_act="gelu",
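
Note (not part of the patch): the sketch below illustrates what the new kv_heads option feeds into. In a Hugging Face LlamaConfig, num_key_value_heads equal to num_attention_heads gives standard multi-head attention (MHA), a smaller divisor gives grouped-query attention (GQA), and 1 gives multi-query attention (MQA). The head counts and hidden sizes here are made up for illustration and are not the values used by vall_e.

	# minimal sketch, assuming Hugging Face transformers is installed;
	# build_tiny_llama and its sizes are hypothetical, only the
	# kv_heads-to-num_key_value_heads mapping mirrors the patched lines
	from transformers import LlamaConfig, LlamaForCausalLM

	def build_tiny_llama(n_heads: int, kv_heads: int) -> LlamaForCausalLM:
		config = LlamaConfig(
			vocab_size=256,
			hidden_size=256,
			intermediate_size=512,
			num_hidden_layers=2,
			num_attention_heads=n_heads,
			# same fallback as the patch: non-positive kv_heads means MHA
			num_key_value_heads=kv_heads if kv_heads > 0 else n_heads,
		)
		return LlamaForCausalLM(config)

	if __name__ == "__main__":
		mha = build_tiny_llama(n_heads=8, kv_heads=0)  # 8 KV heads (MHA)
		gqa = build_tiny_llama(n_heads=8, kv_heads=4)  # 4 KV heads (GQA)
		# GQA shrinks k_proj/v_proj (and the KV cache) by n_heads / kv_heads
		print(mha.model.layers[0].self_attn.k_proj.weight.shape)  # (256, 256)
		print(gqa.model.layers[0].self_attn.k_proj.weight.shape)  # (128, 256)

The query projection and output projection keep their full size; only the key/value projections (and therefore the KV cache at inference) shrink, which is why the quality impact is expected to be small at this model size.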