From 6afc2b7526a5f5303e05171f7f0a66ae527a997f Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 7 Mar 2025 13:51:59 -0600
Subject: [PATCH] gut feeling to change the attention mask

---
 data/qnt.nem                |  Bin 0 -> 3751 bytes
 vall_e/config.py            |   2 ++
 vall_e/models/arch/llama.py |  53 ++++++++++++++++++++++++++++++++++--
 vall_e/models/base_v2.py    |  20 ++++++++++++++
 4 files changed, 73 insertions(+), 2 deletions(-)
 create mode 100644 data/qnt.nem

diff --git a/data/qnt.nem b/data/qnt.nem
new file mode 100644
index 0000000000000000000000000000000000000000..ba29df06d55bcc31c7399dcbd93754abec6b8629
GIT binary patch
literal 3751
[base85-encoded payload of the 3751-byte binary file not reproduced here]

literal 0
HcmV?d00001

diff --git a/vall_e/config.py b/vall_e/config.py
index e4c29cb..722a272 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -287,6 +287,8 @@ class ModelExperimentalSettings:
 	# "normal" will do the FSQ strat (prioritize midrange)
 	# "equal" or "none" will set do no leveling
 	# list of floats to manually set
+	use_segmented_attention_mask: bool = False # instead of naively using a full attention mask, use one where each segment cannot attend to positions after it
+	# this is a flag since I am cautious
 
 	# these technically should be as hyperparameters
 	# performs token dropout to compensate for errors
diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py
index 88adce5..d27b627 100644
--- a/vall_e/models/arch/llama.py
+++ b/vall_e/models/arch/llama.py
@@ -140,6 +140,8 @@ class Attention(nn.Module):
 		self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 		self.scaling = self.head_dim**-0.5
 		self.attention_dropout = config.attention_dropout
+		# legacy
+		self.hidden_size = config.hidden_size
 
 		self.num_heads = config.num_attention_heads
 		self.num_key_value_heads = config.num_key_value_heads
@@ -540,7 +542,7 @@ class Model(LlamaPreTrainedModel):
 		# Initialize weights and apply final processing
 		self.post_init()
 
-	# shamelessly borrowed from https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256 until I replace it with my own noncausal-mask maker
+	# shamelessly inspired by https://github.com/open-mmlab/Amphion/blob/main/models/tts/maskgct/llama_nar.py#L256
 	def _update_noncausal_mask(
 		self,
 		attention_mask,
@@ -563,6 +565,50 @@
 		inverted_mask = 1.0 - expanded_mask
 		return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
 
+	# some funky segmented-attention mask because my gut says to do this
+	def _update_segmented_mask(
+		self,
+		attention_mask,
+		inputs_embeds,
+		aux_lens, # (bsz, lens), where [batch_index, 0] = text_len, and [batch_index, 1] = prom_len
+		past_key_values_length=0,
+	):
+		# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+		bsz, seq_len, _ = inputs_embeds.size()
+
+		if attention_mask is None:
+			attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device)
+
+		expanded_mask = torch.zeros(
+			(bsz, 1, seq_len, seq_len),
+			dtype=inputs_embeds.dtype,
+			device=inputs_embeds.device
+		)
+
+		for batch_index, aux_len in enumerate( aux_lens ):
+			text_start, text_end = 0, aux_len[0]
+
+			prom_start, prom_end = text_end, text_end + aux_len[1]
+			output_start = prom_end
+
+			print( text_start, text_end )
+			print( prom_start, prom_end )
+			print( output_start )
+
+			expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
+			expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
+			expanded_mask[batch_index, 0, output_start:, :] = 1.0
+
+		# apply the original attention mask
+		expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
+
+		# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
+		inverted_mask = 1.0 - expanded_mask
+		return inverted_mask.masked_fill(
+			inverted_mask.to(dtype=torch.bool),
+			torch.finfo(inputs_embeds.dtype).min
+		)
+
 	@staticmethod
 	def _prepare_4d_causal_attention_mask_with_cache_position(
 		attention_mask: torch.Tensor,
@@ -695,8 +741,11 @@
 		if position_ids is None:
 			position_ids = cache_position.unsqueeze(0)
 
+		# use already crafted mask
+		if attention_mask.dim() > 2:
+			x_mask = attention_mask
 		# because we can attend to both a causal and a non-causal sequence, generate both masks then pick among which to use per batch
-		if is_causal is not None:
+		elif is_causal is not None:
 			causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
 			noncausal_mask = self._update_noncausal_mask(attention_mask, inputs_embeds, past_key_values)
 
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index b586623..97915e7 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -331,6 +331,7 @@ class Base_V2(nn.Module):
 		audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
 		logit_normalization = config.experimental.logit_normalization if config is not None else 0
 		per_level_normalization = config.experimental.per_level_normalization if config is not None else True
+		use_segmented_attention_mask = config.experimental.use_segmented_attention_mask if config is not None else True
 
 		n_vocab = 256
 		n_tasks = config.tasks if config is not None else 8
@@ -419,6 +420,7 @@ class Base_V2(nn.Module):
 		self.noncausal_masks = noncausal_masks
 		self.audio_level_loss_factors = audio_level_loss_factors
 		self.logit_normalization = logit_normalization
+		self.use_segmented_attention_mask = use_segmented_attention_mask
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
@@ -1217,6 +1219,24 @@ class Base_V2(nn.Module):
 		# right now limit to new versions because I need to retrain the model for noncausal masks...
 		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
 
+		# create special masks
+		# to-do, create it if mixed (although I expect this model to be purely non-causal)
+		if self.use_segmented_attention_mask and not any(is_causal):
+			aux_lens = torch.zeros((batch_size, 2), device=x.device, dtype=torch.int32)
+			# fill aux lens
+			for batch_index, batch_input in enumerate( inputs ):
+				for name, input in batch_input:
+					if name in ["phn", "text"]:
+						aux_lens[batch_index][0] = input.shape[0]
+					elif name == "lang":
+						aux_lens[batch_index][0] += 2
+					elif name == "prom":
+						aux_lens[batch_index][1] = input.shape[0]
+					elif name == "tone":
+						aux_lens[batch_index][1] += 2
+
+			mask = self.model._update_segmented_mask( mask, x, aux_lens )
+
 		output = self._forward(
 			inputs=x,
 			mask=mask,
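
For intuition, the per-item mask that _update_segmented_mask builds (before the original padding mask is multiplied in and the result is inverted to the additive 0 / -inf form) follows a simple block pattern: the text segment attends only to itself, the prom segment attends to text + prom, and everything after the prom attends to the full sequence. The sketch below is a minimal standalone illustration of that pattern in plain PyTorch; the segmented_mask helper and its arguments are illustrative only and are not part of this patch.

import torch

def segmented_mask(text_len: int, prom_len: int, seq_len: int) -> torch.Tensor:
	# rows are queries, columns are keys; 1 = may attend, 0 = masked
	mask = torch.zeros(seq_len, seq_len)
	text_end = text_len
	prom_end = text_len + prom_len
	mask[:text_end, :text_end] = 1.0          # text tokens attend only to the text segment
	mask[text_end:prom_end, :prom_end] = 1.0  # prom tokens attend to text + prom
	mask[prom_end:, :] = 1.0                  # output tokens attend to everything
	return mask

print(segmented_mask(text_len=3, prom_len=2, seq_len=7).int())
# tensor([[1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)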