gut feeling to change the attention mask
parent 91ede71cf0
commit 6afc2b7526
BIN  data/qnt.nem
Binary file not shown.
@@ -287,6 +287,8 @@ class ModelExperimentalSettings:
     # "normal" will do the FSQ strat (prioritize midrange)
     # "equal" or "none" will do no leveling
     # list of floats to manually set
+    use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend to segments after itself
+    # this is a flag since I am cautious
 
     # these technically should be hyperparameters
     # performs token dropout to compensate for errors
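For intuition: with the segments being text, acoustic prompt, and output (as the mask builder later in this commit lays them out), the segmented mask is block-structured, so a segment sees itself and everything before it but nothing after it. A minimal standalone sketch with made-up segment lengths, not the repo's API:

import torch

# hypothetical segment lengths: 3 text tokens, 2 prompt tokens, 4 output tokens
text_len, prom_len, out_len = 3, 2, 4
seq_len = text_len + prom_len + out_len

allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
allowed[:text_len, :text_len] = True                                 # text -> text only
allowed[text_len:text_len + prom_len, :text_len + prom_len] = True   # prompt -> text + prompt
allowed[text_len + prom_len:, :] = True                              # output -> everything

print(allowed.int())  # rows are query positions; 1 = may attend, 0 = masked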
@@ -140,6 +140,8 @@ class Attention(nn.Module):
         self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
+        # legacy
+        self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
 
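The num_key_value_groups line above is the usual grouped-query-attention bookkeeping: each key/value head serves num_attention_heads // num_key_value_heads query heads, so K/V get repeated along the head dimension before the attention product. A hedged standalone sketch of that expansion, analogous to the repeat_kv helper in Hugging Face Llama code rather than anything added in this commit:

import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)
    if n_rep == 1:
        return hidden
    bsz, num_kv_heads, seq_len, head_dim = hidden.shape
    hidden = hidden[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden.reshape(bsz, num_kv_heads * n_rep, seq_len, head_dim)

# e.g. 16 query heads sharing 4 K/V heads -> num_key_value_groups == 4
k = torch.randn(1, 4, 10, 64)
print(repeat_kv(k, 16 // 4).shape)  # torch.Size([1, 16, 10, 64])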
@@ -540,7 +542,7 @@ class Model(LlamaPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    # shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
+    # shamelessly inspired from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256
     def _update_noncausal_mask(
         self,
         attention_mask,
@@ -563,6 +565,50 @@ class Model(LlamaPreTrainedModel):
         inverted_mask = 1.0 - expanded_mask
         return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+    # some funky segmented-attention mask because my gut says to do this
+    def _update_segmented_mask(
+        self,
+        attention_mask,
+        inputs_embeds,
+        aux_lens, # (bsz, lens), where [batch_index, 0] = text_len, and [batch_index, 1] = prom_len
+        past_key_values_length=0,
+    ):
+        # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+        bsz, seq_len, _ = inputs_embeds.size()
+
+        if attention_mask is None:
+            attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device)
+
+        expanded_mask = torch.zeros(
+            (bsz, 1, seq_len, seq_len),
+            dtype=inputs_embeds.dtype,
+            device=inputs_embeds.device
+        )
+
+        for batch_index, aux_len in enumerate( aux_lens ):
+            text_start, text_end = 0, aux_len[0]
+
+            prom_start, prom_end = text_end, text_end + aux_len[1]
+            output_start = prom_end
+
+            print( text_start, text_end )
+            print( prom_start, prom_end )
+            print( output_start )
+
+            expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
+            expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
+            expanded_mask[batch_index, 0, output_start:, :] = 1.0
+
+        # apply the original attention mask
+        expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+
+        # invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
+        inverted_mask = 1.0 - expanded_mask
+        return inverted_mask.masked_fill(
+            inverted_mask.to(dtype=torch.bool),
+            torch.finfo(inputs_embeds.dtype).min
+        )
+
     @staticmethod
     def _prepare_4d_causal_attention_mask_with_cache_position(
         attention_mask: torch.Tensor,
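Both _update_noncausal_mask and _update_segmented_mask finish with the same inversion trick: a 1.0/0.0 "may attend" matrix is flipped so 1.0 marks the masked positions, and those entries are filled with the dtype's most negative value, giving an additive mask for the attention logits. A small sketch of just that step with toy values:

import torch

allowed = torch.tensor([[1., 1., 0.],
                        [1., 1., 0.],
                        [1., 1., 1.]])   # 1.0 = attend, 0.0 = masked

inverted = 1.0 - allowed                 # now 1.0 marks the masked positions
additive = inverted.masked_fill(inverted.to(dtype=torch.bool), torch.finfo(torch.float32).min)

print(additive)  # 0.0 where attention is allowed, ~-3.4e38 where it is masked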
@@ -695,8 +741,11 @@ class Model(LlamaPreTrainedModel):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # use already crafted mask
+        if attention_mask.dim() > 2:
+            x_mask = attention_mask
         # because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
-        if is_causal is not None:
+        elif is_causal is not None:
             causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
             noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
 
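Because is_causal is tracked per batch entry, one way to pick between the two prebuilt 4D masks is a per-batch select; the sketch below only illustrates that idea with a boolean flag per entry, not necessarily how the surrounding (elided) code combines them:

import torch

bsz, seq_len = 2, 5
min_val = torch.finfo(torch.float32).min
causal_mask = torch.full((bsz, 1, seq_len, seq_len), min_val).triu(1)  # -inf above the diagonal
noncausal_mask = torch.zeros((bsz, 1, seq_len, seq_len))               # everything visible

is_causal = torch.tensor([True, False])  # one flag per batch entry

x_mask = torch.where(is_causal[:, None, None, None], causal_mask, noncausal_mask)
print(x_mask[0, 0])  # causal (upper triangle masked); x_mask[1, 0] is all zeros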
@@ -331,6 +331,7 @@ class Base_V2(nn.Module):
         audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
         logit_normalization = config.experimental.logit_normalization if config is not None else 0
         per_level_normalization = config.experimental.per_level_normalization if config is not None else True
+        use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
 
         n_vocab = 256
         n_tasks = config.tasks if config is not None else 8
@@ -419,6 +420,7 @@ class Base_V2(nn.Module):
         self.noncausal_masks = noncausal_masks
         self.audio_level_loss_factors = audio_level_loss_factors
         self.logit_normalization = logit_normalization
+        self.use_segmented_attention_mask = use_segmented_attention_mask
 
         self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -1217,6 +1219,24 @@ class Base_V2(nn.Module):
         # right now limit to new versions because I need to retrain the model for noncausal masks...
         is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
 
+        # create special masks
+        # to-do, create it if mixed (although I expect this model to be purely non-causal)
+        if self.use_segmented_attention_mask and not any(is_causal):
+            aux_lens = torch.zeros((batch_size, 2), device=x.device, dtype=torch.int32)
+            # fill aux lens
+            for batch_index, batch_input in enumerate( inputs ):
+                for name, input in batch_input:
+                    if name in ["phn", "text"]:
+                        aux_lens[batch_index][0] = input.shape[0]
+                    elif name == "lang":
+                        aux_lens[batch_index][0] += 2
+                    elif name == "prom":
+                        aux_lens[batch_index][1] = input.shape[0]
+                    elif name == "tone":
+                        aux_lens[batch_index][1] += 2
+
+            mask = self.model._update_segmented_mask( mask, x, aux_lens )
+
         output = self._forward(
             inputs=x,
             mask=mask,
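To make the length bookkeeping concrete, here is a hedged sketch of what aux_lens ends up holding for one made-up batch entry of (name, tensor) pairs, mirroring the loop above (the names follow the loop; the tensor lengths are invented):

import torch

batch_input = [
    ("phn",  torch.zeros(12)),   # 12 phoneme tokens
    ("lang", torch.zeros(1)),    # counted as +2 toward the text segment
    ("prom", torch.zeros(75)),   # 75 acoustic-prompt tokens
    ("tone", torch.zeros(1)),    # counted as +2 toward the prompt segment
]

aux_lens = torch.zeros((1, 2), dtype=torch.int32)
for name, tensor in batch_input:
    if name in ["phn", "text"]:
        aux_lens[0][0] = tensor.shape[0]
    elif name == "lang":
        aux_lens[0][0] += 2
    elif name == "prom":
        aux_lens[0][1] = tensor.shape[0]
    elif name == "tone":
        aux_lens[0][1] += 2

print(aux_lens)  # tensor([[14, 77]], dtype=torch.int32): text_len = 14, prom_len = 77

These two lengths are exactly what _update_segmented_mask reads as aux_len[0] and aux_len[1] when carving out the text and prompt blocks.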