diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index fc27fe9..0c4c056 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -184,6 +184,8 @@ class LlamaAttention_Adapted(LlamaAttention):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

+        attn_scores = None
+
         if mode in ["xformers", "flash_attn"]:
             # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
             # to be able to avoid many of these transpose/reshape/view.
@@ -241,9 +243,6 @@ class LlamaAttention_Adapted(LlamaAttention):
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        # to-do: actually find what is our attention scores, since these seem to not vary at all
-        attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
-
         causal_mask = attention_mask
         if attention_mask is not None:
             causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0780e26..778e208 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -543,8 +543,7 @@ class Base(nn.Module):

 		if attention_backend == "flash_attn_v100":
 			self.l_padding = 32
-
-		if attention_backend == "fused_attn":
+		elif attention_backend == "fused_attn":
 			self.l_padding = 128

 		if self.arch_type == "transformer":
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 468abd3..589af0a 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -107,7 +107,7 @@ def load_sample( speaker ):

 	return data, (sr, wav)

-def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="auto"):
+def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention=None):
 	global tts

 	if tts is not None:
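For orientation (not part of the patch itself): a minimal sketch of the control flow this change leaves around attn_scores, assuming an HF-style attention forward. The function name and signature are hypothetical, and scaled_dot_product_attention stands in for the backend-specific kernels.

import torch
import torch.nn.functional as F

def adapted_attention_sketch(query_states, key_states, value_states, mode, causal_mask=None):
    # Hoisted by this patch: attn_scores is bound once, before the backend branches,
    # so it is defined on every path instead of only on the default one.
    attn_scores = None

    if mode in ["xformers", "flash_attn"]:
        # Fused backends never materialize the score matrix; attn_scores stays None.
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, is_causal=True
        )
    else:
        # Default path: the explicit matmul / sqrt(head_dim) score computation that used
        # to fill attn_scores under output_attentions was dropped, so this branch also
        # leaves attn_scores as None.
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=causal_mask
        )

    # Callers that request output_attentions now receive None for the weights.
    return attn_output, attn_scores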