I should really just grab modelling_llama wholesale (fix for the adapted attention class)
This commit is contained in:
parent e5f9da2221
commit 0841f366e8
@@ -35,6 +35,11 @@ class LlamaAttention_Adapted(LlamaAttention):
         super().__init__(*args, **kwargs)
 
+        if not hasattr(self, "num_heads"):
+            self.num_heads = self.config.num_attention_heads
+        if not hasattr(self, "num_key_value_heads"):
+            self.num_key_value_heads = self.config.num_key_value_heads
+
     # extracts inputs from a batch based on requested causality
     def split_forward(
         self,
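For reference, a minimal runnable sketch of what this guard is for: newer Hugging Face transformers releases dropped num_heads and num_key_value_heads from LlamaAttention itself, so the adapted subclass falls back to deriving them from the config when the parent constructor no longer sets them. The imports, the LlamaConfig values, and the bare usage below are assumptions for illustration; the real LlamaAttention_Adapted carries additional logic (e.g. split_forward) not shown here.

    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaAttention

    class LlamaAttention_Adapted(LlamaAttention):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # Older transformers set these in the parent constructor; newer
            # releases dropped them, so fall back to the config when absent.
            if not hasattr(self, "num_heads"):
                self.num_heads = self.config.num_attention_heads
            if not hasattr(self, "num_key_value_heads"):
                self.num_key_value_heads = self.config.num_key_value_heads

    if __name__ == "__main__":
        # Hypothetical config values, only to show the attributes resolving.
        config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=8)
        attn = LlamaAttention_Adapted(config, layer_idx=0)
        print(attn.num_heads, attn.num_key_value_heads)  # -> 8 8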