diff --git a/vall_e/demo.py b/vall_e/demo.py
index 9a0f557..0c4d8b1 100644
--- a/vall_e/demo.py
+++ b/vall_e/demo.py
@@ -64,6 +64,7 @@ def main():
 	parser.add_argument("--top-p", type=float, default=1.0)
 	parser.add_argument("--top-k", type=int, default=0)
+	parser.add_argument("--min-p", type=float, default=0.0)
 	parser.add_argument("--repetition-penalty", type=float, default=1.0)
 	parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
 	parser.add_argument("--length-penalty", type=float, default=0.0)
@@ -121,8 +122,15 @@ def main():
 		comparison_kwargs["enabled"] = True
 		comparison_kwargs["suffix"] = "_entropix"
 		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
 
+		comparison_kwargs["before"]["entropix_sampling"] = True
+		comparison_kwargs["before"]["ar_temp"] = 0.666
+		comparison_kwargs["before"]["top_k"] = 27
+		comparison_kwargs["before"]["top_p"] = 0.9
 
 		comparison_kwargs["after"]["entropix_sampling"] = False
+		comparison_kwargs["after"]["ar_temp"] = args.ar_temp
+		comparison_kwargs["after"]["top_k"] = args.top_k
+		comparison_kwargs["after"]["top_p"] = args.top_p
 
 	# read html template
@@ -135,7 +143,7 @@ def main():
 		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
 		ar_temp=args.ar_temp, nar_temp=args.nar_temp, min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
-		top_p=args.top_p, top_k=args.top_k,
+		top_p=args.top_p, top_k=args.top_k, min_p=args.min_p,
 		repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
 		length_penalty=args.length_penalty,
 		beam_width=args.beam_width,
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 2285bcd..d7ec30b 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -150,7 +150,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 
 		super().__init__(*args, **kwargs)
 
-	# Adapted from LlamaAttention.forward
+	# Adapted from LlamaAttention.forward, this doesn't seem to give great output......
 	def _forward(
 		self,
 		hidden_states: torch.Tensor,
@@ -163,41 +163,18 @@ class LlamaAttention_Adapted(LlamaAttention):
 		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
 		**kwargs,
 	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+		dropout_rate = self.attention_dropout if self.training else 0.0
 		bsz, q_len, _ = hidden_states.size()
 
-		if self.config.pretraining_tp > 1:
-			key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-			query_slices = self.q_proj.weight.split(
-				(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-			)
-			key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-			value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-			query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-			query_states = torch.cat(query_states, dim=-1)
-
-			key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-			key_states = torch.cat(key_states, dim=-1)
-
-			value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-			value_states = torch.cat(value_states, dim=-1)
-
-		else:
-			query_states = self.q_proj(hidden_states)
-			key_states = self.k_proj(hidden_states)
-			value_states = self.v_proj(hidden_states)
+		query_states = self.q_proj(hidden_states)
+		key_states = self.k_proj(hidden_states)
+		value_states = self.v_proj(hidden_states)
 
 		query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 		key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 		value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 		if position_embeddings is None:
-			logger.warning_once(
-				"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
-				"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
-				"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
-				"removed and `position_embeddings` will be mandatory."
-			)
 			cos, sin = self.rotary_emb(value_states, position_ids)
 		else:
 			cos, sin = position_embeddings
@@ -259,7 +236,7 @@ class LlamaAttention_Adapted(LlamaAttention):
 		position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
 		**kwargs,
 	) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-		if output_attentions or not self.mode:
+		if not self.mode:
 			# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
 			return self._forward(
 				hidden_states=hidden_states,
@@ -295,6 +272,9 @@ class LlamaAttention_Adapted(LlamaAttention):
 			key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
 		if self.mode in ["xformers", "flash_attn"]:
+			if output_attentions:
+				attn_scores = torch.matmul(query_states, repeat_kv(key_states, self.num_key_value_groups).transpose(2, 3)) / math.sqrt(self.head_dim)
+
 			# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
 			# to be able to avoid many of these transpose/reshape/view.
 			query_states = query_states.transpose(1, 2)
@@ -346,10 +326,11 @@ class LlamaAttention_Adapted(LlamaAttention):
 			attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 			attn_output = self.o_proj(attn_output)
 
-			return attn_output, None, past_key_value
+			return attn_output, attn_scores, past_key_value
 
 		key_states = repeat_kv(key_states, self.num_key_value_groups)
 		value_states = repeat_kv(value_states, self.num_key_value_groups)
+		attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if output_attentions else None
 
 		causal_mask = attention_mask
 		if attention_mask is not None:
@@ -391,4 +372,4 @@ class LlamaAttention_Adapted(LlamaAttention):
 
 		attn_output = self.o_proj(attn_output)
 
-		return attn_output, None, past_key_value
\ No newline at end of file
+		return attn_output, attn_scores, past_key_value
\ No newline at end of file
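
For reference, `demo.py` only plumbs the new `--min-p` flag through to the inference call; the filter itself lives in the sampler code, which this diff doesn't touch. A minimal sketch of how min-p filtering is typically implemented (an illustration under that assumption, not vall_e's actual sampler; `min_p_filter` is a made-up name):

```python
# Sketch only: tokens whose probability falls below min_p * (top token's probability)
# are masked out before sampling; min_p = 0.0 (the new flag's default) disables the filter.
import torch

def min_p_filter(logits: torch.Tensor, min_p: float = 0.0) -> torch.Tensor:
	if min_p <= 0.0:
		return logits
	probs = torch.softmax(logits, dim=-1)
	threshold = min_p * probs.max(dim=-1, keepdim=True).values
	return logits.masked_fill(probs < threshold, float("-inf"))

# Usage: filter, then renormalize and sample as usual.
logits = torch.randn(1024)
token = torch.multinomial(torch.softmax(min_p_filter(logits, min_p=0.1), dim=-1), num_samples=1)
```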
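On the `llama.py` side, the fused xformers/flash_attn paths never materialize the attention matrix, so the patch recomputes the scores explicitly whenever `output_attentions` is requested. What gets returned is the raw `Q·K^T / sqrt(head_dim)` logits, before the causal mask and softmax, rather than the normalized weights stock HuggingFace attention returns. A standalone sketch of the same computation (shapes assumed to match the diff; `repeat_kv` is reimplemented inline for illustration):

```python
# Sketch only: recompute pre-softmax attention scores for GQA, mirroring the formula
# the diff adds: torch.matmul(q, repeat_kv(k, groups).transpose(2, 3)) / sqrt(head_dim).
import math
import torch

def manual_attn_scores(q: torch.Tensor, k: torch.Tensor, num_key_value_groups: int) -> torch.Tensor:
	# q: [bsz, num_heads, q_len, head_dim], k: [bsz, num_kv_heads, kv_len, head_dim]
	bsz, num_kv_heads, kv_len, head_dim = k.shape
	# inline equivalent of transformers' repeat_kv: duplicate KV heads up to num_heads
	k = k[:, :, None, :, :].expand(bsz, num_kv_heads, num_key_value_groups, kv_len, head_dim)
	k = k.reshape(bsz, num_kv_heads * num_key_value_groups, kv_len, head_dim)
	# raw scores; the causal mask and softmax are not applied here
	return torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)

# Usage: toy GQA layout with 8 query heads sharing 4 KV heads (2 groups)
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 4, 16, 64)
scores = manual_attn_scores(q, k, num_key_value_groups=2)  # -> [1, 8, 16, 16]
```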