diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index c701dd8..cee3eed 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -223,9 +223,16 @@ class LlamaAttention_Adapted(LlamaAttention):
 		**kwargs,
 	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
 		mode = "default" if output_attentions else self.mode
+		non_split_attention = [
+			"default",
+			torch.nn.attention.SDPBackend.MATH,
+			torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+			torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+			torch.nn.attention.SDPBackend.CUDNN_ATTENTION
+		]
 
 		# split per batch because other attention mechanisms do not have a conditional is_causal per-batch, only for the entire input
-		if isinstance( is_causal, list ) and mode not in ["default"]:
+		if isinstance( is_causal, list ) and mode not in non_split_attention:
 			# initialize lists
 			attn_hidden_states = [ None for _ in is_causal ]
 			self_attn_weights = [ None for _ in is_causal ]
@@ -282,35 +289,6 @@ class LlamaAttention_Adapted(LlamaAttention):
 
 			return attn_hidden_states, output_attentions, []
 
-		"""
-		h_s = []
-		s_a_w = []
-		p_k_v = []
-
-		for i, state in enumerate(is_causal):
-			hidden_state, self_attn_weight, present_key_value = self.forward(
-				hidden_states=hidden_states[i].unsqueeze(0),
-				attention_mask=attention_mask[i].unsqueeze(0),
-				is_causal=state,
-				position_ids=position_ids[i].unsqueeze(0),
-				past_key_value=past_key_value,
-				output_attentions=output_attentions,
-				use_cache=False,
-				cache_position=cache_position,
-				position_embeddings=(position_embeddings[0][i].unsqueeze(0), position_embeddings[1][i].unsqueeze(0)) if position_embeddings is not None else None,
-				**kwargs,
-			)
-			h_s.append(hidden_state)
-			s_a_w.append(self_attn_weight)
-			p_k_v.append(present_key_value)
-
-		return (
-			torch.concat( h_s, dim=0 ),
-			torch.concat( s_a_w, dim=0 ) if s_a_w else None,
-			p_k_v,
-		)
-		"""
-
 		dropout_rate = self.attention_dropout if self.training else 0.0
 
 		bsz, q_len, _ = hidden_states.size()
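For context, a minimal sketch (not the repository's code; the helper name `per_sample_attention` is illustrative) of the two paths this change distinguishes: modes outside `non_split_attention` only take a single `is_causal` flag per call, so the batch is split per sample, while the `torch.nn.attention.SDPBackend` values added to the list are selected for one batched call via `sdpa_kernel`.

```python
# Hedged sketch under assumed shapes; requires PyTorch with torch.nn.attention (2.3+).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def per_sample_attention(q, k, v, is_causal):
    # q, k, v: [batch, heads, seq, head_dim]; is_causal: list[bool], one entry per sample.
    # scaled_dot_product_attention accepts one is_causal bool per call, so each sample
    # is attended separately and the results are re-batched.
    outputs = [
        F.scaled_dot_product_attention(q[i:i+1], k[i:i+1], v[i:i+1], is_causal=causal)
        for i, causal in enumerate(is_causal)
    ]
    return torch.cat(outputs, dim=0)

q = k = v = torch.randn(2, 4, 8, 16)

# Split path: per-sample is_causal handled by looping over the batch.
split_out = per_sample_attention(q, k, v, is_causal=[True, False])

# Non-split path: one of the backends named in non_split_attention, one call for the whole batch.
with sdpa_kernel(SDPBackend.MATH):
    batched_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(split_out.shape, batched_out.shape)  # both torch.Size([2, 4, 8, 16])
```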