From a755eb3c62bd595893445bd55562f9195051b3c1 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sat, 11 May 2024 17:34:45 -0500
Subject: [PATCH] ugh

---
 vall_e/models/base.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index dc12392..55a84e0 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -207,13 +207,13 @@ try:
 		if self.mode == "xformers":
 			if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
-				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=None, p=dropout_rate)
 			else:
-				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask())
+				attn_output = memory_efficient_attention(query_states, key_states, value_states, attn_bias=LowerTriangularMask(), p=dropout_rate)
 		else:
 			#torch.nn.attention.sdpa_kernel
 			with torch.backends.cuda.sdp_kernel(enable_flash=self.mode == "flash", enable_math=self.mode == "math", enable_mem_efficient=self.mode == "mem_efficient"):
-				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask)
+				attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=dropout_rate)
 		attn_weights = None
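
Note: below is a minimal sketch (not the repository's exact code) of what this patch changes. The attention dropout probability is now threaded into both backends, as p= for xformers' memory_efficient_attention and dropout_p= for torch.nn.functional.scaled_dot_product_attention, where it was previously omitted and attention effectively ran with no dropout. The function name, the tensor layout, the training gate, and the use of is_causal in place of an explicit attention mask are assumptions made for illustration; mode, dropout_rate, and the backend selection mirror the names used in the diff.

    import torch
    import torch.nn.functional as F

    try:
        # optional dependency; only needed for mode == "xformers"
        from xformers.ops import memory_efficient_attention, LowerTriangularMask
        HAS_XFORMERS = True
    except ImportError:
        HAS_XFORMERS = False

    def attend(q, k, v, mode="math", causal=True, dropout_rate=0.0, training=False):
        # q, k, v: [batch, heads, seq_len, head_dim]
        # dropout is only applied while training (assumption; the patch itself
        # just forwards dropout_rate unconditionally)
        p = dropout_rate if training else 0.0

        if mode == "xformers" and HAS_XFORMERS:
            # xformers expects [batch, seq_len, heads, head_dim]
            q, k, v = (t.transpose(1, 2) for t in (q, k, v))
            bias = LowerTriangularMask() if causal else None
            out = memory_efficient_attention(q, k, v, attn_bias=bias, p=p)
            return out.transpose(1, 2)

        # select the SDPA backend; torch.backends.cuda.sdp_kernel is the older
        # context manager referenced in the diff (torch.nn.attention.sdpa_kernel
        # is its newer replacement)
        with torch.backends.cuda.sdp_kernel(
            enable_flash=mode == "flash",
            enable_math=mode == "math",
            enable_mem_efficient=mode == "mem_efficient",
        ):
            return F.scaled_dot_product_attention(q, k, v, dropout_p=p, is_causal=causal)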