respect attention defined in the yaml for the web UI (which might explain why there's been a discrepancy in outputs for me)
This commit is contained in: parent ed6b7a690f, commit c800d28bb8
@@ -184,6 +184,8 @@ class LlamaAttention_Adapted(LlamaAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        attn_scores = None
+
         if mode in ["xformers", "flash_attn"]:
             # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
             # to be able to avoid many of these transpose/reshape/view.
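Note on the TODO above: flash-attention kernels expect [batch_size, sequence_length, num_heads, head_dim], while the eager path keeps [batch_size, num_heads, sequence_length, head_dim], hence the transposes the comment complains about. A minimal, self-contained sketch of that layout shuffle (illustrative shapes only, not code from this repo):

import torch

batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
query_states = torch.randn(batch, num_heads, seq_len, head_dim)

# [B, H, S, D] -> [B, S, H, D]; .contiguous() forces the copy
# that the TODO calls "quite inefficient".
flash_layout = query_states.transpose(1, 2).contiguous()
assert flash_layout.shape == (batch, seq_len, num_heads, head_dim)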
@@ -241,9 +243,6 @@ class LlamaAttention_Adapted(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        # to-do: actually find what is our attention scores, since these seem to not vary at all
-        attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
-
         causal_mask = attention_mask
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
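The removed line computed raw scaled dot-product scores (pre-softmax logits), and only when output_attentions was requested. A minimal sketch of what that expression produces, assuming the usual [batch, heads, seq, head_dim] tensors and the same causal-mask slicing as the surrounding context; this is an illustration, not the repo's replacement code:

import math
import torch

def manual_attn_scores(query_states, key_states, head_dim, causal_mask=None):
    # [B, H, S_q, D] @ [B, H, D, S_k] -> [B, H, S_q, S_k]
    scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    if causal_mask is not None:
        # mirror the context lines above: trim the additive mask to the key length
        scores = scores + causal_mask[:, :, :, : key_states.shape[-2]]
    # these are pre-softmax logits; scores.softmax(dim=-1) would give attention weights
    return scores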
@@ -543,8 +543,7 @@ class Base(nn.Module):
 
         if attention_backend == "flash_attn_v100":
             self.l_padding = 32
-
-        if attention_backend == "fused_attn":
+        elif attention_backend == "fused_attn":
             self.l_padding = 128
 
         if self.arch_type == "transformer":
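Turning the second check into an elif makes the two backends mutually exclusive; l_padding reads as a sequence-length alignment (32 for flash_attn_v100, 128 for fused_attn). How such a multiple is typically applied, sketched with a hypothetical helper that is not part of this repo:

import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple):
    # x: [batch, seq_len, dim]; pad seq_len up to the next multiple of `multiple`
    seq_len = x.shape[1]
    pad = (multiple - seq_len % multiple) % multiple
    return F.pad(x, (0, 0, 0, pad)) if pad else x

x = torch.randn(1, 100, 1024)
print(pad_to_multiple(x, 32).shape)   # torch.Size([1, 128, 1024])  (flash_attn_v100)
print(pad_to_multiple(x, 128).shape)  # torch.Size([1, 128, 1024])  (fused_attn)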
@@ -107,7 +107,7 @@ def load_sample( speaker ):
 
     return data, (sr, wav)
 
-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
     global tts
 
     if tts is not None:
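Changing the default from "auto" to None is what lets the YAML-configured attention backend win when the web UI does not pass one explicitly, matching the commit message. A hedged sketch of that fall-through pattern (the helper and its other arguments are hypothetical, not taken from the repo):

def resolve_attention(cli_attention, yaml_attention, fallback="auto"):
    if cli_attention is not None:      # explicitly requested by the caller
        return cli_attention
    if yaml_attention is not None:     # otherwise respect the YAML config
        return yaml_attention
    return fallback                    # last resort

print(resolve_attention(None, "flash_attn"))        # flash_attn (YAML wins)
print(resolve_attention("xformers", "flash_attn"))  # xformers (explicit override)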