I should really just grab modelling_llama wholesale (fix for the adapted attention class)

mrq 2025-01-28 21:55:05 -06:00
parent e5f9da2221
commit 0841f366e8


@@ -35,6 +35,11 @@ class LlamaAttention_Adapted(LlamaAttention):
 		super().__init__(*args, **kwargs)
+		if not hasattr(self, "num_heads"):
+			self.num_heads = self.config.num_attention_heads
+		if not hasattr(self, "num_key_value_heads"):
+			self.num_key_value_heads = self.config.num_key_value_heads
+
 	# extracts inputs from a batch based on requested causality
 	def split_forward(
 		self,
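
For reference, a minimal self-contained sketch of what the added fallback does, assuming a recent transformers release where LlamaAttention no longer exposes num_heads / num_key_value_heads as instance attributes but still stores the config on self.config. This is not the repo's exact file; the config values in the usage example are made up for illustration.

	# Sketch of the compatibility shim around transformers' LlamaAttention.
	from transformers import LlamaConfig
	from transformers.models.llama.modeling_llama import LlamaAttention

	class LlamaAttention_Adapted(LlamaAttention):
		def __init__(self, *args, **kwargs):
			super().__init__(*args, **kwargs)
			# Newer transformers releases dropped these instance attributes;
			# restore them from the config so downstream code that still reads
			# self.num_heads / self.num_key_value_heads keeps working.
			if not hasattr(self, "num_heads"):
				self.num_heads = self.config.num_attention_heads
			if not hasattr(self, "num_key_value_heads"):
				self.num_key_value_heads = self.config.num_key_value_heads

	# Usage with hypothetical config values:
	cfg = LlamaConfig(hidden_size=256, num_attention_heads=4, num_key_value_heads=2)
	attn = LlamaAttention_Adapted(cfg, layer_idx=0)
	print(attn.num_heads, attn.num_key_value_heads)  # 4 2

Guarding with hasattr means the subclass stays a no-op on older transformers versions that still set these attributes themselves, so the same code path works on both sides of the upstream change.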