From 0841f366e86379d17c24f8c935e349c78759ded4 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 28 Jan 2025 21:55:05 -0600 Subject: [PATCH] I should really just grab modelling_llama wholesale (fix for the adapted attention class) --- vall_e/models/arch/llama.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index 70854bf..5cc8112 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -35,6 +35,11 @@ class LlamaAttention_Adapted(LlamaAttention): super().__init__(*args, **kwargs) + if not hasattr(self, "num_heads"): + self.num_heads = self.config.num_attention_heads + if not hasattr(self, "num_key_value_heads"): + self.num_key_value_heads = self.config.num_key_value_heads + # extracts inputs from a batch based on requested causality def split_forward( self,