diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 59b7b83..03093a7 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -589,48 +589,28 @@ class LlamaModel_Adapted(LlamaModel):
 
 	# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
 	def _update_noncausal_mask(
-		self, attention_mask, inputs_embeds, past_key_values_length
+		self,
+		attention_mask,
+		inputs_embeds,
+		past_key_values_length,
 	):
 		# create noncausal mask
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		combined_attention_mask = None
+		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
 
-		input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
-
-		def _expand_mask(
-			mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
-		):
-			"""
-			Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-			"""
-			bsz, src_len = mask.size()
-			tgt_len = tgt_len if tgt_len is not None else src_len
-
-			expanded_mask = (
-				mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-			)
-
-			inverted_mask = 1.0 - expanded_mask
-
-			return inverted_mask.masked_fill(
-				inverted_mask.to(torch.bool), torch.finfo(dtype).min
-			)
+		bsz, seq_len, _ = inputs_embeds.size()
 
+		# generate default mask based on input
 		if attention_mask is None:
-			attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
+			attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
 
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		expanded_attn_mask = _expand_mask(
-			attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-		).to(inputs_embeds.device)
-		combined_attention_mask = (
-			expanded_attn_mask
-			if combined_attention_mask is None
-			else expanded_attn_mask + combined_attention_mask
-		)
+		# make square
+		expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
 
-		return combined_attention_mask
+		# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
+		inverted_mask = 1.0 - expanded_mask
+		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+	# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
 	def _update_causal_mask(
 		self,
 		attention_mask: torch.Tensor,
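
For reference, a minimal standalone sketch of what the rewritten _update_noncausal_mask computes, assuming only torch (the build_noncausal_bias name and the example tensors are illustrative, not part of the patch): a [bsz, seq_len] boolean padding mask is expanded into a square additive bias of shape [bsz, 1, seq_len, seq_len], with 0.0 where attention is allowed and the dtype minimum where it is masked, and no causal triangle is applied.

import torch
from typing import Optional

def build_noncausal_bias(attention_mask: Optional[torch.Tensor], inputs_embeds: torch.Tensor) -> torch.Tensor:
    # sketch mirroring the patched _update_noncausal_mask: only padding is masked, no causal triangle
    bsz, seq_len, _ = inputs_embeds.size()

    # default to "attend everywhere" when no padding mask is supplied
    if attention_mask is None:
        attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device)

    # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]: every query row shares the same key mask
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype=inputs_embeds.dtype)

    # invert 1.0 = attend / 0.0 = masked into an additive bias: 0.0 = valid, dtype min = masked
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min)

# usage: batch of 2, sequence length 4, second sequence padded after 2 tokens
embeds = torch.zeros(2, 4, 8)
padding = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool)
bias = build_noncausal_bias(padding, embeds)
assert bias.shape == (2, 1, 4, 4)
assert (bias[1, 0, :, 2:] == torch.finfo(embeds.dtype).min).all()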