ugh
This commit is contained in:
parent 88e9b9caff
commit a755eb3c62
@@ -207,13 +207,13 @@ try:
         if self.mode == "xformers":
             if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-                attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+                attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None, p=dropout_rate)
             else:
-                attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+                attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
         else:
             #torch.nn.attention.sdpa_kernel
             with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
-                attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+                attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)

         attn_weights = None

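For context, a minimal runnable sketch of the dispatch this hunk modifies, with the two new dropout arguments in place. Only the two attention calls and the mask check come from the diff above; the function name attention_forward, its signature, and the mode/dropout_rate plumbing are assumptions for illustration, not the repository's actual code:

import torch

def attention_forward(query_states, key_states, value_states,
                      attention_mask=None, mode="math", dropout_rate=0.0):
    # Note: xformers expects (batch, seq, heads, head_dim) inputs while SDPA
    # expects (batch, heads, seq, head_dim); the real module would transpose
    # accordingly before reaching this point.
    if mode == "xformers":
        # Requires the xformers package; p is the probability of dropping an
        # attention weight, applied inside the fused kernel.
        from xformers.ops import LowerTriangularMask, memory_efficient_attention
        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
            # Non-causal case (or no mask at all): no attention bias needed.
            return memory_efficient_attention(
                query_states, key_states, value_states,
                attn_bias=None, p=dropout_rate)
        # Causal case: let xformers construct the lower-triangular bias itself.
        return memory_efficient_attention(
            query_states, key_states, value_states,
            attn_bias=LowerTriangularMask(), p=dropout_rate)
    # PyTorch path. torch.backends.cuda.sdp_kernel is deprecated in newer
    # PyTorch releases in favor of torch.nn.attention.sdpa_kernel, which the
    # commented line in the diff already hints at.
    with torch.backends.cuda.sdp_kernel(
            enable_flash=(mode == "flash"),
            enable_math=(mode == "math"),
            enable_mem_efficient=(mode == "mem_efficient")):
        # SDPA applies dropout whenever dropout_p > 0; it has no notion of
        # train/eval mode, so callers normally pass 0.0 at inference time.
        return torch.nn.functional.scaled_dot_product_attention(
            query_states, key_states, value_states,
            attn_mask=attention_mask, dropout_p=dropout_rate)

In the real module the rate would typically be gated on self.training (e.g. dropout_rate if self.training else 0.0) so that evaluation stays deterministic.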