added much cleaner non-causal mask generation

mrq 2024-11-22 19:43:32 -06:00
parent c99a74e834
commit 41d7c30ea5


@@ -589,48 +589,28 @@ class LlamaModel_Adapted(LlamaModel):
-	# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
 	def _update_noncausal_mask(
-		self, attention_mask, inputs_embeds, past_key_values_length
+		self,
+		attention_mask,
+		inputs_embeds,
+		past_key_values_length,
 	):
-		# create noncausal mask
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		combined_attention_mask = None
-		input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
-		def _expand_mask(
-			mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
-		):
-			"""
-			Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-			"""
-			bsz, src_len = mask.size()
-			tgt_len = tgt_len if tgt_len is not None else src_len
-			expanded_mask = (
-				mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-			)
-			inverted_mask = 1.0 - expanded_mask
-			return inverted_mask.masked_fill(
-				inverted_mask.to(torch.bool), torch.finfo(dtype).min
-			)
+		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+		bsz, seq_len, _ = inputs_embeds.size()
+		# generate default mask based on input
 		if attention_mask is None:
-			attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
+			attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		expanded_attn_mask = _expand_mask(
-			attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-		).to(inputs_embeds.device)
-		combined_attention_mask = (
-			expanded_attn_mask
-			if combined_attention_mask is None
-			else expanded_attn_mask + combined_attention_mask
-		)
-		return combined_attention_mask
+		# make square
+		expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
+		# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
+		inverted_mask = 1.0 - expanded_mask
+		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
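As a sanity check on the new behaviour, here is a minimal standalone sketch of the non-causal mask construction introduced above, assuming only a plain PyTorch environment; the helper name build_noncausal_mask and the toy shapes are illustrative and not part of the repository, only the masking arithmetic mirrors the diff:

import torch
from typing import Optional

def build_noncausal_mask(attention_mask: Optional[torch.Tensor], inputs_embeds: torch.Tensor) -> torch.Tensor:
	# [bsz, seq_len, dim] embeddings -> additive mask of shape [bsz, 1, seq_len, seq_len]
	bsz, seq_len, _ = inputs_embeds.size()
	# with no padding mask given, default to attending to every position
	if attention_mask is None:
		attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
	# broadcast the [bsz, seq_len] key-padding mask into a square [bsz, 1, seq_len, seq_len] mask
	expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
	# flip 1.0 = attend / 0.0 = masked into 0.0 = valid / -inf = masked additive biases
	inverted_mask = 1.0 - expanded_mask
	return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )

if __name__ == "__main__":
	x = torch.randn(1, 4, 8)            # toy [bsz=1, seq_len=4, dim=8] embeddings
	pad = torch.tensor([[1, 1, 1, 0]])  # last position is padding
	print(build_noncausal_mask(pad, x)[0, 0])  # only the padded column is masked; no causal triangle

Unlike a causal mask, every row of the result is identical, so each position can attend to every non-padded position. Such an additive mask can then be passed to attention explicitly (e.g. as attn_mask to torch.nn.functional.scaled_dot_product_attention) rather than relying on is_causal to synthesize one, which is what the _update_causal_mask note above is getting at.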