respect attention defined in the yaml for web UI (which might explain why there's been a discrepancy in outputs for me)

mrq 2024-10-13 11:02:24 -05:00
parent ed6b7a690f
commit c800d28bb8
3 changed files with 4 additions and 6 deletions
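
A plausible reason for the output discrepancy mentioned in the commit message: different attention kernels are not bit-identical, so a web UI that silently overrides the YAML's attention backend can produce slightly different outputs than inference run directly against the config. A minimal sketch using only stock PyTorch (not code from this repo) comparing the reference math kernel against whatever backend PyTorch dispatches to by default:

# Minimal sketch, stock PyTorch only (torch >= 2.3 for torch.nn.attention).
# Different scaled-dot-product-attention kernels may differ by small
# floating-point amounts, which can compound over an autoregressive decode.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

with sdpa_kernel(SDPBackend.MATH):          # force the reference implementation
    out_math = F.scaled_dot_product_attention(q, k, v)

out_default = F.scaled_dot_product_attention(q, k, v)  # default dispatch

print((out_math - out_default).abs().max())  # usually tiny, not guaranteed zero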

View File

@@ -184,6 +184,8 @@ class LlamaAttention_Adapted(LlamaAttention):
     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+attn_scores = None
 if mode in ["xformers", "flash_attn"]:
     # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
     # to be able to avoid many of these transpose/reshape/view.
@@ -241,9 +243,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 key_states = repeat_kv(key_states, self.num_key_value_groups)
 value_states = repeat_kv(value_states, self.num_key_value_groups)
-# to-do: actually find what is our attention scores, since these seem to not vary at all
-attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
 causal_mask = attention_mask
 if attention_mask is not None:
     causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
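
Note on the two hunks above: attn_scores is now pre-initialized to None before the backend branches, and the conditional matmul that used to populate it is dropped, so the variable is always defined even when a fused backend (xformers/flash_attn) never materializes attention weights. A minimal, self-contained sketch of the pattern (hypothetical toy function, not the repo's forward pass):

# Hypothetical sketch: pre-initializing the optional output keeps it defined
# no matter which branch runs, avoiding an UnboundLocalError on the fast path.
import math
import torch
import torch.nn.functional as F

def toy_attention(q, k, v, mode="math", output_attentions=False):
    attn_scores = None  # always defined, mirroring the added line in the diff
    if mode in ("xformers", "flash_attn"):
        # fused kernels return the output directly and never expose scores;
        # stock SDPA stands in for them here
        out = F.scaled_dot_product_attention(q, k, v)
    else:
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        if output_attentions:
            attn_scores = scores  # raw scaled scores, like the removed line
        out = scores.softmax(dim=-1) @ v
    return out, attn_scores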

View File

@@ -543,8 +543,7 @@ class Base(nn.Module):
 if attention_backend == "flash_attn_v100":
     self.l_padding = 32
-if attention_backend == "fused_attn":
+elif attention_backend == "fused_attn":
     self.l_padding = 128
 if self.arch_type == "transformer":

View File

@@ -107,7 +107,7 @@ def load_sample( speaker ):
     return data, (sr, wav)
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
     global tts
     if tts is not None:
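
With the default changed from "auto" to None, the web UI no longer forces an attention backend on its own, so the value from the loaded YAML can take effect. A rough sketch of the intended precedence (resolve_attention and its arguments are placeholders, not the repo's actual API):

# Placeholder sketch of the precedence implied by the change: an explicit
# argument wins, None defers to the YAML, and "auto" is only a last resort.
def resolve_attention(arg_attention=None, yaml_attention=None, fallback="auto"):
    if arg_attention is not None:
        return arg_attention       # caller explicitly picked a backend
    if yaml_attention is not None:
        return yaml_attention      # otherwise respect the YAML config
    return fallback                # nothing configured anywhere

assert resolve_attention(None, "sdpa") == "sdpa"                  # YAML wins
assert resolve_attention("flash_attention_2", "sdpa") == "flash_attention_2"
assert resolve_attention() == "auto"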