I should really just grab modelling_llama wholesale (fix for the adapted attention class)
This commit is contained in:
parent
e5f9da2221
commit
0841f366e8
|
@ -35,6 +35,11 @@ class LlamaAttention_Adapted(LlamaAttention):
|
|||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if not hasattr(self, "num_heads"):
|
||||
self.num_heads = self.config.num_attention_heads
|
||||
if not hasattr(self, "num_key_value_heads"):
|
||||
self.num_key_value_heads = self.config.num_key_value_heads
|
||||
|
||||
# extracts inputs from a batch based on requested causality
|
||||
def split_forward(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue
Block a user