respect the attention backend defined in the yaml for the web UI (which might explain why there's been a discrepancy in outputs for me)
This commit is contained in:
parent ed6b7a690f
commit c800d28bb8
@@ -184,6 +184,8 @@ class LlamaAttention_Adapted(LlamaAttention):
 		cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
 		key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+		attn_scores = None
+
 		if mode in ["xformers", "flash_attn"]:
 			# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
 			# to be able to avoid many of these transpose/reshape/view.
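For context, past_key_value.update() follows the transformers cache contract: append this step's key/value states for the given layer and return the full cached sequence so attention can run over all past tokens. A minimal sketch of that contract, with an illustrative stand-in class (MinimalKVCache is not the real transformers Cache):

import torch

class MinimalKVCache:
	# illustrative stand-in for the object passed as past_key_value
	def __init__(self):
		self.key_cache = {}
		self.value_cache = {}

	def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
		# key_states / value_states: [batch, num_kv_heads, new_tokens, head_dim]
		if layer_idx not in self.key_cache:
			self.key_cache[layer_idx] = key_states
			self.value_cache[layer_idx] = value_states
		else:
			# append along the sequence-length dimension
			self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
			self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
		# return the full cached keys/values, as the attention code above expects
		return self.key_cache[layer_idx], self.value_cache[layer_idx]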
@@ -241,9 +243,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 		key_states = repeat_kv(key_states, self.num_key_value_groups)
 		value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-		# to-do: actually find what is our attention scores, since these seem to not vary at all
-		attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
-
 		causal_mask = attention_mask
 		if attention_mask is not None:
 			causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
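The dropped expression is just the standard pre-softmax scaled dot-product logits, QK^T / sqrt(head_dim). As a standalone reference (the helper name is illustrative, and key_states are assumed to already have repeat_kv applied):

import math
import torch

def scaled_dot_product_logits(query_states, key_states, head_dim):
	# query_states: [batch, num_heads, q_len, head_dim]
	# key_states:   [batch, num_heads, kv_len, head_dim] (after repeat_kv)
	# raw attention logits before masking and softmax
	return torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)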
@@ -543,8 +543,7 @@ class Base(nn.Module):
 
 		if attention_backend == "flash_attn_v100":
 			self.l_padding = 32
-
-		if attention_backend == "fused_attn":
+		elif attention_backend == "fused_attn":
 			self.l_padding = 128
 
 		if self.arch_type == "transformer":
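l_padding presumably asks the rest of the model to pad sequence lengths up to a multiple that the chosen backend's kernels expect (32 for the V100 flash-attention path, 128 for fused_attn). A sketch of that rounding, under the assumption that this is how l_padding is consumed elsewhere; the helper name is hypothetical:

def pad_to_multiple(seq_len: int, l_padding: int) -> int:
	# round seq_len up to the nearest multiple of l_padding; 0 means no padding requirement
	if not l_padding:
		return seq_len
	return ((seq_len + l_padding - 1) // l_padding) * l_padding

# e.g. pad_to_multiple(300, 32) -> 320, pad_to_multiple(300, 128) -> 384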
@@ -107,7 +107,7 @@ def load_sample( speaker ):
 
 	return data, (sr, wav)
 
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
 	global tts
 
 	if tts is not None:
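This signature change is the fix named in the commit message: with a default of "auto" the web UI always passed an explicit attention value and so overrode whatever the YAML defined, while None lets the YAML-configured backend win. A minimal sketch of that precedence, with illustrative names rather than the repo's actual config code:

from typing import Optional

def resolve_attention_backend(yaml_attention: str, requested: Optional[str] = None) -> str:
	# None means "caller expressed no preference", so the yaml-defined backend is respected;
	# any explicit value (e.g. "flash_attn", "xformers", "auto") still overrides it
	return requested if requested is not None else yaml_attention

# old default: resolve_attention_backend("flash_attn", "auto") -> "auto" (yaml ignored)
# new default: resolve_attention_backend("flash_attn", None)   -> "flash_attn"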