added much cleaner non-causal mask generation
commit 41d7c30ea5
parent c99a74e834
@@ -589,48 +589,28 @@ class LlamaModel_Adapted(LlamaModel):
 	# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
 	def _update_noncausal_mask(
-		self, attention_mask, inputs_embeds, past_key_values_length
+		self,
+		attention_mask,
+		inputs_embeds,
+		past_key_values_length,
 	):
 		# create noncausal mask
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		combined_attention_mask = None
-
-		input_shape = (inputs_embeds.shape[0], inputs_embeds.shape[1])
-
-		def _expand_mask(
-			mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
-		):
-			"""
-			Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-			"""
-			bsz, src_len = mask.size()
-			tgt_len = tgt_len if tgt_len is not None else src_len
-
-			expanded_mask = (
-				mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
-			)
-
-			inverted_mask = 1.0 - expanded_mask
-
-			return inverted_mask.masked_fill(
-				inverted_mask.to(torch.bool), torch.finfo(dtype).min
-			)
-
+		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+		bsz, seq_len, _ = inputs_embeds.size()
+
+		# generate default mask based on input
 		if attention_mask is None:
-			attention_mask = torch.ones( input_shape, dtype=torch.bool, device=inputs_embeds.device )
+			attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
 
-		# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-		expanded_attn_mask = _expand_mask(
-			attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-		).to(inputs_embeds.device)
-		combined_attention_mask = (
-			expanded_attn_mask
-			if combined_attention_mask is None
-			else expanded_attn_mask + combined_attention_mask
-		)
-
-		return combined_attention_mask
+		# make square
+		expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
+
+		# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
+		inverted_mask = 1.0 - expanded_mask
+
+		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+	# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
 	def _update_causal_mask(
 		self,
 		attention_mask: torch.Tensor,
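To illustrate the new construction, the sketch below replays the same expand-then-invert steps outside the class so the shapes can be checked in isolation. The wrapper name make_noncausal_mask and the toy inputs are mine, not part of the commit.

import torch

def make_noncausal_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: [bsz, seq_len], 1/True = attend, 0/False = padding (hypothetical helper)
    bsz, seq_len = attention_mask.shape
    # broadcast each key's mask bit across every query row -> [bsz, 1, seq_len, seq_len]
    expanded_mask = attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
    # flip to the additive convention: 0.0 = valid, dtype-min (effectively -inf) = masked
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

# toy batch of two sequences; the second one ends in a padding token
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
out = make_noncausal_mask(mask, torch.float32)
print(out.shape)   # torch.Size([2, 1, 3, 3])
print(out[1, 0])   # last column holds torch.finfo(torch.float32).min; all other entries are 0.0

Because the mask depends only on key positions, padding shows up as fully masked columns; every query can still attend to every valid token in both directions, which is what makes the mask non-causal.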
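The _update_causal_mask hunk is cut off above, but the comment over it suggests the intent: rather than returning no mask and deferring to SDPA's is_causal path, the override materializes an explicit additive mask. A minimal sketch of why the two routes should agree; all names and shapes here are illustrative, not from the repo:

import torch
import torch.nn.functional as F

bsz, heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)
v = torch.randn(bsz, heads, seq_len, head_dim)

# route 1: defer to SDPA, which synthesizes the causal mask internally
implicit = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# route 2: materialize the equivalent additive mask and pass it explicitly
causal_mask = torch.full((seq_len, seq_len), torch.finfo(q.dtype).min).triu(diagonal=1)
explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

assert torch.allclose(implicit, explicit, atol=1e-6)

An explicit mask costs O(seq_len^2) memory, but it composes with padding or other custom masks by simple addition, whereas SDPA rejects is_causal=True combined with a supplied attn_mask; that restriction is plausibly among the "problems" the comment alludes to.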