added much cleaner non-causal mask generation
parent c99a74e834
commit 41d7c30ea5
@@ -589,48 +589,28 @@ class LlamaModel_Adapted(LlamaModel):
-	# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
 	def _update_noncausal_mask(
-		self, attention_mask, inputs_embeds, past_key_values_length
+		self,
+		attention_mask,
+		inputs_embeds,
+		past_key_values_length,
 	):
-		# create noncausal mask
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		combined_attention_mask = None
-
-		input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
-
-		def _expand_mask(
-			mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
-		):
-			"""
-			Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-			"""
-			bsz, src_len = mask.size()
-			tgt_len = tgt_len if tgt_len is not None else src_len
-
-			expanded_mask = (
-				mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-			)
-
-			inverted_mask = 1.0 - expanded_mask
-
-			return inverted_mask.masked_fill(
-				inverted_mask.to(torch.bool), torch.finfo(dtype).min
-			)
+		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+		bsz, seq_len, _ = inputs_embeds.size()
 
 		# generate default mask based on input
 		if attention_mask is None:
-			attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
+			attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
 
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		expanded_attn_mask = _expand_mask(
-			attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-		).to(inputs_embeds.device)
-		combined_attention_mask = (
-			expanded_attn_mask
-			if combined_attention_mask is None
-			else expanded_attn_mask + combined_attention_mask
-		)
+		# make square
+		expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
 
-		return combined_attention_mask
+		# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
+		inverted_mask = 1.0 - expanded_mask
+		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+	# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
 	def _update_causal_mask(
 		self,
 		attention_mask: torch.Tensor,
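
For reference, a minimal standalone sketch of the new mask construction with a quick shape check. build_noncausal_mask is a hypothetical free-function name used only for illustration (the commit's version is the _update_noncausal_mask method above); its body mirrors the added lines.

from typing import Optional

import torch

def build_noncausal_mask( attention_mask: Optional[torch.Tensor], inputs_embeds: torch.Tensor ) -> torch.Tensor:
	# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
	bsz, seq_len, _ = inputs_embeds.size()

	# no mask given: attend to everything
	if attention_mask is None:
		attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )

	# make square: every query position sees every non-padded key position
	expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )

	# invert from 1.0 = attend, 0.0 = masked to additive form: 0.0 = valid, -inf = masked
	inverted_mask = 1.0 - expanded_mask
	return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )

# toy batch: the second sequence has its last position padded out
embeds = torch.zeros( 2, 4, 8 )
padding = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=torch.bool)
mask = build_noncausal_mask( padding, embeds )
print( mask.shape )     # torch.Size([2, 1, 4, 4])
print( mask[1, 0, 0] )  # last key column is torch.finfo(dtype).min for the padded sequence

The padded key column ends up at torch.finfo(dtype).min for every query row, which is the additive-mask equivalent of masking that position out.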
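
As a further illustration of the point behind the gutted _update_causal_mask (not shoving mask generation onto SDPA's is_causal): a sketch, with made-up tensor names and sizes, of passing the additive mask to torch.nn.functional.scaled_dot_product_attention explicitly so SDPA never has to synthesize a mask of its own.

import torch
import torch.nn.functional as F

bsz, heads, seq_len, head_dim = 2, 4, 4, 16
q = torch.randn( bsz, heads, seq_len, head_dim )
k = torch.randn( bsz, heads, seq_len, head_dim )
v = torch.randn( bsz, heads, seq_len, head_dim )

# same recipe as the commit: [bsz, seq_len] padding mask -> square additive mask
padding = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=torch.bool)
expanded = padding[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=q.dtype )
inverted = 1.0 - expanded
attn_mask = inverted.masked_fill( inverted.to(dtype=torch.bool), torch.finfo(q.dtype).min )

# with an explicit additive mask and is_causal=False, SDPA applies exactly the
# bidirectional mask built above instead of generating a causal one itself
out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, is_causal=False )
print( out.shape )  # torch.Size([2, 4, 4, 16])

The [bsz, 1, seq_len, seq_len] mask broadcasts across the head dimension, so one mask serves all attention heads.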