From b52c5c5d80f3f985e630ca94d642288e566ca7a7 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 12 Feb 2025 16:16:04 -0600
Subject: [PATCH] this seems to work in testing

---
 data/qnt.enc               | Bin 13073 -> 3353 bytes
 vall_e/engines/__init__.py |   1 +
 vall_e/models/ar_nar.py    |  92 +++++------
 vall_e/models/base.py      | 324 ++++++++++++++++---------------------
 4 files changed, 190 insertions(+), 227 deletions(-)

diff --git a/data/qnt.enc b/data/qnt.enc
index eede679951856ced7a862b34c9a551e191fa7edd..7905f135a458d919ad9b57ba2630b4d15114d82b 100644
GIT binary patch
[binary literals for data/qnt.enc omitted: 13073 -> 3353 bytes]

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 2ceb836..831e1b7 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -397,6 +397,7 @@ def load_engines(training=True, **model_kwargs):
key_name = cfg.lora.full_name
kwargs['name'] = 'job'
+ kwargs['resume'] = 'allow'
if world_size() > 1:
kwargs["group"] = "DDP"
kwargs['name'] = f'job-{global_rank()}'
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 654cfde..afd1e1c 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -134,13 +134,11 @@ class AR_NAR(Base):
# trim resps to only contain all levels below the target level
if self.version < 7:
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
- elif not self.parallel_decoding:
- resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
- audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
+ audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7
else self.n_resp_levels)], device=device, dtype=torch.int16) # final validations and stuff for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list): @@ -173,7 +171,7 @@ class AR_NAR(Base): resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 # only apply stop token for RVQ level 0 - if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding: + if quant_level <= 0 and timesteps[i] is None: # append stop tokens for AR if task not in text_task: resps_list[i] = torch.cat([ resps, audio_stop_sequence ]) @@ -1103,54 +1101,56 @@ class AR_NAR(Base): # is NAR if (len_list is not None or resps_list is not None) and text_list is not None: if self.version >= 7: - if self.parallel_decoding: - return self.forward_nar_masked_parallel( - task_list=task_list, + return self.forward_nar_masked_parallel( + task_list=task_list, - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - raw_text_list=raw_text_list, + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + raw_text_list=raw_text_list, - disable_tqdm=disable_tqdm, - use_lora=use_lora, - **sampling_kwargs, - ) - else: - resps_lists = [ None for _ in range(batch_size) ] - for level in range(self.n_resp_levels): - resp_list = self.forward_nar_masked( - task_list=task_list, + disable_tqdm=disable_tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - lang_list=lang_list, - tone_list=tone_list, - len_list=len_list, - raw_text_list=raw_text_list, + # NAR demasking for all levels + """ + resps_lists = [ None for _ in range(batch_size) ] + for level in range(self.n_resp_levels): + resp_list = self.forward_nar_masked( + task_list=task_list, - disable_tqdm=disable_tqdm, - use_lora=use_lora, - quant_levels=[ level for _ in range(batch_size) ], - **sampling_kwargs, - ) + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + raw_text_list=raw_text_list, - for batch_index, resp in enumerate(resp_list): - if resps_lists[batch_index] is None: - resps_lists[batch_index] = [] - - resps_lists[batch_index].append( resp ) + disable_tqdm=disable_tqdm, + use_lora=use_lora, + quant_levels=[ level for _ in range(batch_size) ], + **sampling_kwargs, + ) - for batch_index, resps in enumerate(resps_lists): - resps_lists[batch_index] = torch.stack( resps, dim=-1 ) + for batch_index, resp in enumerate(resp_list): + if resps_lists[batch_index] is None: + resps_lists[batch_index] = [] + + resps_lists[batch_index].append( resp ) - return resps_lists + for batch_index, resps in enumerate(resps_lists): + resps_lists[batch_index] = torch.stack( resps, dim=-1 ) + + return resps_lists + """ return self.forward_nar( task_list=task_list, @@ -1254,7 +1254,7 @@ def example_usage(): available_tasks = ["tts-nar"] model = AR_NAR(**kwargs).to(cfg.device) - steps = 500 // batch_size + steps = 750 // batch_size optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy" scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else "" diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 88c68ed..c401f08 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -245,21 +245,6 @@ class AudioEmbedding(nn.Module): 
return x -class AudioEmbedding_Sums(nn.Module): - def __init__( - self, - n_tokens: int, - n_levels: int, - token_dim: int, - ): - super().__init__() - self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)]) - - def forward(self, xi: Tensor ) -> Tensor: - x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] ) - - return x - # time-step embedding # for the NAR-len, since it probably most likely requires encoding the timestep class TimeEmbedding(nn.Module): @@ -318,10 +303,6 @@ class Classifiers(nn.Module): levels = [] # map names to levels - """ - if names and not levels: - levels = [ None if name =="NAR" else self.names.index(name) for name in names ] - """ if names and not levels: levels = [ None if name not in self.names else self.names.index(name) for name in names ] @@ -341,9 +322,36 @@ class Classifiers(nn.Module): ] return torch.stack( xi ) +# naively embeds each level of a codebook, then merges the embeddings with a Linear +class AudioEncoder(nn.Module): + def __init__( + self, + n_tokens: int, + n_levels: int, + token_dim: int, + ): + super().__init__() + self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)]) + self.proj = nn.Linear(8 * token_dim, 1 * token_dim) + + def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor: + if dropout_mask is not None: + xi = xi.clone().detach().t() + for l, t in enumerate( xi ): + xi[l] = torch.where( dropout_mask, dropout_token, xi[l] ) + xi = xi.t() + + x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1) + x = self.proj(x) + """ + x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ]) + """ + + return x + # Pseudo-MoE by doing additional decoding from the main transformer's last hidden output # ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment -class ParallelDecoder(nn.Module): +class AudioDecoder(nn.Module): def __init__( self, levels, @@ -356,60 +364,39 @@ class ParallelDecoder(nn.Module): attention_backend = config_kwargs.pop("attention_backend", "default") gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True) + config_kwargs["hidden_size"] *= levels + config_kwargs["vocab_size"] *= levels + hidden_size = config_kwargs.get("hidden_size") vocab_size = config_kwargs.get("vocab_size") #self.d_model = d_model self.vocab_size = vocab_size + self.up = nn.Linear( d_model, hidden_size ) + self.down = nn.Linear( hidden_size, vocab_size ) + self.transformer = None + """ + self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs)) + self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) + + if hasattr( self.transformer, "embeddings" ): + del self.transformer.embeddings - downs = [] - modules = [] - ups = [] - for level in range(levels): - module = LlamaModel_Adapted(LlamaConfig(**config_kwargs)) - - module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) - - if hasattr( module, "embeddings" ): - del module.embeddings - - if gradient_checkpointing and not module.gradient_checkpointing: - module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( - use_reentrant=False - )) - - modules.append(module) - downs.append(nn.Linear(d_model, hidden_size, bias=False)) - ups.append(nn.Linear(hidden_size, vocab_size, bias=False)) - - self.levels = levels - self.decoders = 
nn.ModuleList(modules) - self.downs = nn.ModuleList(downs) - self.ups = nn.ModuleList(ups) + if gradient_checkpointing and not self.transformer.gradient_checkpointing: + self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( + use_reentrant=False + )) + """ def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor: - # split into levels - if level == None: - x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ] - x = torch.stack( x ) - x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification ) - return x + x = self.up( x ) + if self.transformer is not None: + x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"] + x = self.down( x ) - # do one level - - # attention + feedforward - """ - x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"] - # this really hates an output head, so just treat the final output as one - x = x[..., :self.vocab_size] - - """ - # downscale to head's dimensionality - x = self.downs[level]( x ) - # attention + feed forward - x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"] - # upscale to vocab logits - x = self.ups[level]( x ) + batch_size, seq_len, dim = x.shape + x = x.reshape( batch_size, seq_len, 8, dim // 8 ) + x = x.permute( 0, 2, 1, 3 ) return x """ @@ -572,7 +559,6 @@ class Base(nn.Module): self.causal = "ar" in self.capabilities or "len" in self.capabilities self.version = self.config.version if self.config is not None else 5 self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0) - self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False self.arch_type = self.config.arch_type if self.config is not None else "llama" @@ -634,22 +620,10 @@ class Base(nn.Module): l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) else: - if self.parallel_decoding: - n_resp_tokens = n_audio_tokens + 1 - l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels - l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )] - l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels - else: - """ - n_resp_tokens = n_audio_tokens + 1 - l_embedding_tokens = [n_resp_tokens * self.n_resp_levels] - l_embedding_names = ["NAR"] - l_classifier_tokens = [n_audio_tokens * self.n_resp_levels] - """ - n_resp_tokens = n_audio_tokens + 1 - l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels - l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels - l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ] + n_resp_tokens = n_audio_tokens + 1 + l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels + l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )] + l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1 @@ -692,6 +666,7 @@ class Base(nn.Module): n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model) self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic) + self.audio_emb = None elif self.version < 5: # [1024] * 8 self.proms_emb = AudioEmbedding_Old( @@ -703,7 +678,8 @@ class 
Base(nn.Module): l_embedding_tokens, d_model, levels=self.n_resp_levels if self.version > 3 else None, ) - elif not self.parallel_decoding: + self.audio_emb = None + elif self.version < 7: self.proms_emb = AudioEmbedding( [n_audio_tokens] * self.n_resp_levels, d_model, sums=audio_embedding_sums == "prom" or audio_embedding_sums == True, @@ -713,17 +689,11 @@ class Base(nn.Module): sums=audio_embedding_sums == "resp" or audio_embedding_sums == True, l_embedding_names=l_embedding_names, ) + self.audio_emb = None else: - self.proms_emb = AudioEmbedding_Sums( - n_tokens=n_audio_tokens, - n_levels=self.n_resp_levels, - token_dim=d_model, - ) - self.resps_emb = AudioEmbedding_Sums( - n_tokens=n_audio_tokens + 1, - n_levels=self.n_resp_levels, - token_dim=d_model, - ) + self.proms_emb = None + self.resps_emb = None + self.audio_emb = None if self.version >= 3: self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None @@ -739,7 +709,7 @@ class Base(nn.Module): # this ***might*** let me also unify the proms_emb and resps_embedding if self.version >= 5: # "len" RVQ level-0 gets an additional token - if self.version < 7 or not self.parallel_decoding: + if self.version < 7: self.rvq_l_emb = Embedding(self.n_resp_levels, d_model) # experimental NAR-only mode @@ -747,6 +717,53 @@ class Base(nn.Module): if self.version >= 6: self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model) + if self.version >= 7: + pd_model = d_model // 4 + pd_ffn = pd_model * 4 + pd_heads = n_heads // 4 + pd_layers = 1 + + if False: + self.audio_emb = AudioEncoder( + n_tokens=n_audio_tokens + 1, # masked token + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + else: + self.proms_emb = AudioEncoder( + n_tokens=n_audio_tokens, + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + self.resps_emb = AudioEncoder( + n_tokens=n_audio_tokens + 1, # masked token + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + + self.audio_decoder = AudioDecoder( + self.n_resp_levels, + d_model, + dict( + vocab_size=n_audio_tokens, + hidden_size=pd_model, + max_position_embeddings=max_position_embeddings, + intermediate_size=pd_ffn, + num_hidden_layers=pd_layers, + num_attention_heads=pd_heads, + attention_dropout=p_dropout if training else 0.0, + num_key_value_heads=pd_heads, + hidden_act="gelu", + is_encoder_decoder=False, + is_decoder=True, + attn_implementation="eager", + + training=self.training, + attention_backend=attention_backend, + gradient_checkpointing=self.gradient_checkpointing, + ) + ) + if attention_backend == "auto": attention_backend = "sdpa" """ @@ -906,33 +923,6 @@ class Base(nn.Module): self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias ) self.metrics = Metrics( l_classifier_tokens ) - self.parallel_decoder = None - if self.parallel_decoding: - pd_model = d_model # // 2 - pd_ffn = pd_model * 2 - pd_heads = n_heads // 2 - pd_layers = 1 - - config = dict( - vocab_size=n_audio_tokens, - hidden_size=pd_model, - max_position_embeddings=max_position_embeddings, - intermediate_size=pd_ffn, - num_hidden_layers=pd_layers, - num_attention_heads=pd_heads, - attention_dropout=p_dropout if training else 0.0, - num_key_value_heads=pd_heads, - hidden_act="gelu", - is_encoder_decoder=False, - is_decoder=True, - attn_implementation="eager", - - training=self.training, - attention_backend=attention_backend, - gradient_checkpointing=self.gradient_checkpointing, - ) - self.parallel_decoder = ParallelDecoder( self.n_resp_levels, d_model, config ) - def _forward( 
self, inputs, @@ -1126,8 +1116,8 @@ class Base(nn.Module): inputs[i].append( ( "resp", resps_list[i] ) ) if self.version >= 7: - classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR" - + classifier_level = f"NAR:{quant_level}:{quant_level}" + inputs[i].append( ("classifier_level", classifier_level) ) # Audio length prediction task # Sequence: @@ -1269,29 +1259,16 @@ class Base(nn.Module): input if quant_level == 0 else input[:, :quant_level] ) - if self.version < 7: # or not self.parallel_decoding: + if self.version < 7: return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level offset = 0, ) - - if not self.parallel_decoding: - """ - # provides only one - return self.proms_emb( - input if input.dim() == 1 else input[:, quant_level], - quant_level = 0, # if input.dim() == 1 else input.shape[-1], - offset = 0, - ) - """ - # sums all - return self.proms_emb( - input, - quant_level = quant_level if input.dim() == 1 else input.shape[-1], - offset = 0, - ) + if self.audio_emb is not None: + return self.audio_emb( input ) + return self.proms_emb( input ) # yuck @@ -1358,11 +1335,11 @@ class Base(nn.Module): elif name == "tone" and self.tones_emb is not None: embedding = self.tones_emb( input ) elif name == "resp": - if self.parallel_decoding: - if dropout_mask is not None: - embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() ) + if self.version >= 7: + if self.audio_emb is not None: + embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token ) else: - embedding = self.resps_emb( input ) + embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token ) # if training NAR-len RVQ level 0 elif dropout_mask is not None: embedding = self.resps_emb( @@ -1513,7 +1490,7 @@ class Base(nn.Module): return ids.to(device=device, dtype=torch.int32) - def calc_loss_parallel( + def calc_loss_new( self, inputs: list, logits, @@ -1589,6 +1566,9 @@ class Base(nn.Module): if name != task_outputs.get(task_type, name): if self.ignore_inputs_for_loss: ignored = True + # cringe + if task_type != "tts": + ignored = True else: output_len = seq_len @@ -1602,7 +1582,7 @@ class Base(nn.Module): # perform loss calculation on the individual piece target.append( token ) - if classifier_level != "NAR": + if logits[batch_index].dim() != 3: seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) logit = logits[batch_index] @@ -1620,7 +1600,7 @@ class Base(nn.Module): if compute_acc and False: if self.metrics is not None: - metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) ) + metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) ) else: accuracy_metric = MulticlassAccuracy( logit.shape[-1], @@ -1652,7 +1632,7 @@ class Base(nn.Module): if compute_acc and False: if self.metrics is not None: - metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) ) + metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) ) else: accuracy_metric = MulticlassAccuracy( logit.shape[-1], @@ -1701,9 +1681,6 @@ class Base(nn.Module): if self.version < 7: 
return input if input.dim() == 1 else input[:, quant_level] - if not self.parallel_decoding: - return input if input.dim() == 1 else input[:, quant_level] - return input for batch_index, batch in enumerate(inputs): @@ -1729,8 +1706,6 @@ class Base(nn.Module): # nonautoregressive, parallel elif classifier_level.startswith("NAR:"): causal = False - elif classifier_level == "NAR": - causal = False it = 0 for name, input in batch: @@ -1773,6 +1748,9 @@ class Base(nn.Module): if name != task_outputs.get(task_type, name): if self.ignore_inputs_for_loss: ignored = True + # cringe + if task_type != "tts": + ignored = True else: output_len = seq_len @@ -1909,10 +1887,10 @@ class Base(nn.Module): # needs to be done here as we still have our raw inputs position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None classifier_levels = self.get_input( inputs, name="classifier_level" ) - casual_levels = [ "AR:0:0", "stt", "len", "phn" ] + causal_levels = [ "AR:0:0", "stt", "len", "phn" ] # right now limit to new versions because I need to retrain the model for noncausal masks... - is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ] + is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ] output = self._forward( inputs=x, @@ -1928,26 +1906,19 @@ class Base(nn.Module): logits = [ logit for logit in logits ] - if self.version >= 7 and self.parallel_decoding: - p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ] + if self.version >= 7: + p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ] if p_indices: - p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0) - p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0) - p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0) - p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ] + p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0) - p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal ) + p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0) + p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0) + p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ] + + p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal ) for i, logit in enumerate(p_logits): logits[p_indices[i]] = logit - - """ - logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask, - position_ids=position_ids, - use_cache=False, - return_dict=True, - is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ] - """ # output projection layer # the very, very original implementation multiplied by the mask, but the mask 
only attends to padding, and the padding gets removed anyways @@ -1958,15 +1929,6 @@ class Base(nn.Module): elif self.classifiers is not None: logits = self.classifiers(logits, levels = classifier_levels ) - # Reshape - """ - if self.version >= 7 and not self.parallel_decoding: - for batch_index, logit in enumerate( logits ): - if classifier_levels[batch_index] != "NAR": - continue - logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 ) - """ - # Remove padding logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ] @@ -1977,8 +1939,8 @@ class Base(nn.Module): self.loss = None self.stats = None # compute loss if the target is given - elif self.version >= 7 and self.parallel_decoding: - loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits ) + elif self.version >= 7: + loss, stats = self.calc_loss_new( inputs=inputs, logits=logits ) # include any additional losses (for example: MoE router) if output.loss is not None:
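
Notes on the changes above. The sketches below use illustrative names, shapes, and constants; they are not excerpts from the repository.

load_engines() now sets kwargs['resume'] = 'allow' next to the existing 'name'/'group' keys, which by the look of it are forwarded to wandb.init. Assuming that call path, the effect is simply:

import wandb

# resume="allow" reattaches to an existing run when the run id matches,
# and silently starts a fresh run otherwise - useful when a training job is restarted.
run = wandb.init(project="vall-e", name="job", resume="allow")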
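In ar_nar.py, audio_stop_sequence is widened because version >= 7 keeps every codebook level per timestep: the stop marker appended after an AR response has to fill all levels of its frame instead of a single column. A minimal illustration, with 1024 as an assumed stop token and 8 assumed levels:

import torch

stop_token, n_resp_levels = 1024, 8                       # assumed values, not read from the config

resps = torch.randint(0, 1024, (50, n_resp_levels))       # (timesteps, levels)
stop_row = torch.tensor([[stop_token] * n_resp_levels])   # (1, levels)
resps = torch.cat([resps, stop_row])                      # (51, levels): the final frame is all stop tokens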
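The new AudioEncoder in base.py replaces the removed AudioEmbedding_Sums: it still keeps one nn.Embedding per codebook level, but instead of summing the per-level embeddings it concatenates them and merges them back to the model width with a Linear (the patch hardcodes the input width as 8 * token_dim, i.e. eight levels). A standalone sketch of the same idea, parameterized on the level count and with the dropout path collapsed into a single torch.where:

import torch
import torch.nn as nn

class AudioEncoderSketch(nn.Module):
    # One embedding table per codebook level; the per-level embeddings are
    # concatenated and merged back to token_dim with a single Linear.
    def __init__(self, n_tokens: int, n_levels: int, token_dim: int):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
        self.proj = nn.Linear(n_levels * token_dim, token_dim)

    def forward(self, codes: torch.Tensor, dropout_mask=None, dropout_token=None) -> torch.Tensor:
        # codes: (seq_len, n_levels) integer codebook ids
        if dropout_mask is not None:
            # masked timesteps are replaced across every level with the mask/stop token
            codes = torch.where(dropout_mask.unsqueeze(-1), dropout_token, codes)
        x = torch.cat([emb(codes[:, l]) for l, emb in enumerate(self.embs)], dim=-1)
        return self.proj(x)  # (seq_len, token_dim)

enc = AudioEncoderSketch(n_tokens=1025, n_levels=8, token_dim=1024)
codes = torch.randint(0, 1024, (50, 8))
mask = torch.rand(50) < 0.3
out = enc(codes, dropout_mask=mask, dropout_token=torch.tensor(1024))
print(out.shape)  # torch.Size([50, 1024])

The dropout_mask/dropout_token arguments mirror how the NAR-len training path masks out whole frames across every level at once.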
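AudioDecoder, which replaces ParallelDecoder's per-level stack of small LlamaModel_Adapted decoders, predicts all levels in one pass: the backbone's hidden state is projected up to levels * pd_model, optionally run through a small auxiliary transformer (commented out in this commit), projected down to levels * vocab, and reshaped so each level gets its own (seq_len, vocab) logit slice. A shape-only sketch under the same assumptions (8 levels, 1000-token codebooks, the transformer stage omitted):

import torch
import torch.nn as nn

class AudioDecoderSketch(nn.Module):
    # Up-project the backbone hidden state, down-project to per-level logits,
    # then split the last dimension so every codebook level has its own slice.
    def __init__(self, levels: int, d_model: int, decoder_dim: int, vocab: int):
        super().__init__()
        self.levels = levels
        self.vocab = vocab
        self.up = nn.Linear(d_model, levels * decoder_dim)
        self.down = nn.Linear(levels * decoder_dim, levels * vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) hidden states from the main transformer
        b, t, _ = x.shape
        x = self.down(self.up(x))                      # (batch, seq_len, levels * vocab)
        x = x.reshape(b, t, self.levels, self.vocab)   # (batch, seq_len, levels, vocab)
        return x.permute(0, 2, 1, 3)                   # (batch, levels, seq_len, vocab)

dec = AudioDecoderSketch(levels=8, d_model=1024, decoder_dim=256, vocab=1000)
logits = dec(torch.randn(2, 50, 1024))
print(logits.shape)  # torch.Size([2, 8, 50, 1000])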
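calc_loss_parallel is renamed to calc_loss_new, and the per-batch branch now keys off logits[batch_index].dim() != 3, so the decoder's stacked (levels, seq_len, vocab) output takes the per-level path. The exact reduction is not visible in this hunk; the following is only a guess at its shape, one cross entropy per level averaged across levels:

import torch
import torch.nn.functional as F

# Hypothetical per-level loss over the stacked decoder logits.
# Shapes follow the reshape/permute in AudioDecoder.forward; ignore_index is assumed.
levels, seq_len, vocab = 8, 50, 1000
logits = torch.randn(levels, seq_len, vocab)
target = torch.randint(0, vocab, (seq_len, levels))     # (timesteps, levels) codebook ids

loss = sum(
    F.cross_entropy(logits[l], target[:, l], ignore_index=-100)
    for l in range(levels)
) / levels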