From 7feeb944a082a920a132ea6e80dbfcaf311d7517 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 3 Jun 2024 20:26:27 -0500 Subject: [PATCH] probably insane with even entertaining going this route --- data/config.yaml | 283 ++++++++++++------------ data/qnt.enc | Bin 0 -> 5441 bytes data/qnt.pt | Bin 20599 -> 0 bytes scripts/process_libritts.py | 2 +- vall_e/models/ar_nar.py | 8 +- vall_e/models/base.py | 4 +- vall_e/models/experimental.py | 404 ++++++++++++++++++++++++++++++++++ 7 files changed, 553 insertions(+), 148 deletions(-) mode change 100755 => 100644 data/config.yaml create mode 100644 data/qnt.enc delete mode 100755 data/qnt.pt create mode 100644 vall_e/models/experimental.py diff --git a/data/config.yaml b/data/config.yaml old mode 100755 new mode 100644 index c2ea978..a46cec6 --- a/data/config.yaml +++ b/data/config.yaml @@ -1,140 +1,143 @@ -models: -- name: "ar+nar" - size: "full" - resp_levels: 8 - prom_levels: 8 - tasks: 8 - langs: 2 - tones: 1 - arch_type: llama - training: True - version: 4 - attention: flash_attention_2 - dropout: 0.1 - - loss_factors: - text: 0.1 - resp: 1.0 - -hyperparameters: - autotune: False - autotune_params: - start_profile_step: 1 - end_profile_step: 50 - num_tuning_micro_batch_sizes: 8 - - batch_size: 16 - gradient_accumulation_steps: 4 - gradient_clipping: 1.0 - warmup_steps: 100 - - optimizer: Prodigy - learning_rate: 1.0 - torch_optimizer: True - - scheduler: "" # ScheduleFree - torch_scheduler: True - -evaluation: - batch_size: 8 - frequency: 5000 - size: 8 - - steps: 500 - ar_temperature: 0.95 - nar_temperature: 0.25 - load_disabled_engines: True - -trainer: - #no_logger: True - ddp: False - #check_for_oom: False - iterations: 1_000_000 - - save_tag: step - save_on_oom: True - save_on_quit: True - save_frequency: 250 - export_on_save: True - - keep_last_checkpoints: 4 - - aggressive_optimizations: False - load_disabled_engines: False - - #load_state_dict: True - strict_loading: False - #load_tag: "9500" - #load_states: False - #restart_step_count: True - - gc_mode: None # "global_step" - - weight_dtype: float32 # float16 or bfloat16 - amp: False - - backend: deepspeed - deepspeed: - inferencing: True - zero_optimization_level: 0 - use_compression_training: False - - amp: False - - activation_checkpointing: True - - load_webui: False - -inference: - backend: deepspeed - audio_backend: "dac" - normalize: False - - weight_dtype: float32 # float16 or bfloat16 - amp: False - -optimizations: - injects: False - replace: True - - linear: False - embedding: False - optimizers: True - - bitsandbytes: False - dadaptation: False - bitnet: False - fp8: False - -experimental: True # practically required now it seems - -dataset: - speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" - speaker_group_getter: "lambda p: f'{p.parts[-3]}'" - speaker_languages: - ja: [] - - use_hdf5: True - use_metadata: True - hdf5_flag: r - validate: True - - workers: 2 - cache: True - - duration_range: [3.0, 5.0] - - random_utterance: 1.0 - max_prompts: 1 - prompt_duration: 3.0 - - max_resps: 1 - p_resp_append: 0.25 - - sample_type: path # speaker - - tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"] - - training: [] - validation: [] - noise: [] +sample_rate: 24_000 # 44_000 for dac +audio_backend: "vocos" # or dac + +models: +- name: "ar+nar" + size: "full" + resp_levels: 8 + prom_levels: 8 + tasks: 8 + langs: 2 + tones: 1 + arch_type: llama + training: True + version: 4 + attention: flash_attention_2 + dropout: 0.1 + + loss_factors: + 
text: 0.1 + prom: 0.0 + resp: 1.0 + +hyperparameters: + autotune: False + autotune_params: + start_profile_step: 1 + end_profile_step: 50 + num_tuning_micro_batch_sizes: 8 + + batch_size: 16 + gradient_accumulation_steps: 4 + gradient_clipping: 1.0 + warmup_steps: 100 + + optimizer: Prodigy + learning_rate: 1.0 + torch_optimizer: True + + scheduler: "" # ScheduleFree + torch_scheduler: True + +evaluation: + batch_size: 8 + frequency: 5000 + size: 8 + + steps: 500 + ar_temperature: 0.95 + nar_temperature: 0.25 + load_disabled_engines: True + +trainer: + #no_logger: True + ddp: False + #check_for_oom: False + iterations: 1_000_000 + + save_tag: step + save_on_oom: True + save_on_quit: True + save_frequency: 250 + export_on_save: True + + keep_last_checkpoints: 4 + + aggressive_optimizations: False + load_disabled_engines: False + + #load_state_dict: True + strict_loading: False + #load_tag: "9500" + #load_states: False + #restart_step_count: True + + gc_mode: None # "global_step" + + weight_dtype: float32 # float16 or bfloat16 + amp: False + + backend: deepspeed + deepspeed: + inferencing: True + zero_optimization_level: 0 + use_compression_training: False + + amp: False + + activation_checkpointing: True + + load_webui: False + +inference: + backend: deepspeed + normalize: False + + weight_dtype: float32 # float16 or bfloat16 + amp: False + +optimizations: + injects: False + replace: True + + linear: False + embedding: False + optimizers: True + + bitsandbytes: False + dadaptation: False + bitnet: False + fp8: False + +experimental: True # practically required now it seems + +dataset: + speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" + speaker_group_getter: "lambda p: f'{p.parts[-3]}'" + speaker_languages: + ja: [] + + use_hdf5: True + use_metadata: True + hdf5_flag: r + validate: True + + workers: 2 + cache: True + + duration_range: [3.0, 5.0] + + random_utterance: 1.0 + max_prompts: 1 + prompt_duration: 3.0 + + max_resps: 1 + p_resp_append: 0.25 + + sample_type: path # speaker + + tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"] + + training: [] + validation: [] + noise: [] \ No newline at end of file diff --git a/data/qnt.enc b/data/qnt.enc new file mode 100644 index 0000000000000000000000000000000000000000..8da3c31017e8ea913791a99b34c2cfb0d73796e2 GIT binary patch literal 5441 zcmbtYd3;S*+kMvaM73gU%tI4Jq!rZA)>tvMMI=NKvzz3S7{ZMk8WmEiW|Eqz7ZF3N zK?$l_^V}F)Ld85(#L&b{@vW=h`^We9m-EZH=j?O#e)fKbwbni}Lprzb&tTi}(kUPc^7w1SQ)F;jvIV5(NJ1#!bUC1T>zEvnT>ffg>w3G2J zZHWk%O>lXIxe5fd!#n@`Qj|MksI!pkwQv=Y&<3u8lO|4Zy&mRzBcPpygKeW+g+f9? 
z+P>u9t5dsXZkKn&q+YJViLN36;RX0gozLR_%Z@IeJ z-0-)V7OaJE&8}HdtiW(9ZO8SgC1asg(&Kth>gsR_&;XrgB_vawe6Ir@Nj)BA?Ig#| z5jV=g8-H6pTG(v2q&|vEE#&jAj24#;{QXVG=;v&msd=_i_DG^E=6|+jTWwsHY8ZtJ zTA%N>)&HnBuRUiUwAFJqlVi+LkF}=S#R4Trx>y(X3wqT%NT8lkJa<7-?bY>v?}BK@ z^X%VaJ^NMD2^^z~P0(yi&_fcacdVF(*kc^B;_|n7tZtHSaC64fwTd0o(b%Okaoo~% zr##UKvQxa#Pam>vmIjK)qi@R*8A1>F*l?+U=a!-+HAAl1R>_iWm~Mxyu+Gx^@|$|R z!n1iJ&r*+KCIhsx?Lh0zBc z8Oap(4Y7AQ--77IyF)Uaeo~|}(g*M@W4zpc)y|g1Sp-|4-L-4jhDp{$cbKnEes#S| zH?r??v_hix!3}1?2q|j!<&>V}Zy#E(EVPD?tsCykCT)YDSKjydx(mJCoOeZ$ZR`0P zB!}qTAXzWdxPryV;j9;Fylm7s>_QxFTRgA)Esal-aa(3GZbL9xdf;;nkWfiOUHW4y zTHD9^ri|i$Q9Yv9^{Cd-UJ_#qteAM~22?>0c*!h1ZPAj4# zN-v@Wdg^3P6(H4C$pi%QIKmA7Op7Cl9ywO3SyC<0#U}=(#Nc56)?53YrBFc(dF2~E~FCI)E32fk{(@_A z6%QntWA4%Ns41zo-2#zozuS7fDS?tM-R&0kvR)JAjC9a1bd5}5t!LW+sVOdLEeCBh zin6*V%6ZPFrfw2H%+qe#R_gN@ii`54jJ2`0Tx#0;nySmpM;(%&5Ah*aa0!JqK@Zv` zqRl0)#ozkr2G0IC>cB0Vqz)1|f6v&}F&by>Ij^tsToLDqW819}excQI_BAmfkLa*Q zis)ea_5j=RjABwdMvauAi|EV{{h{ZQHy0XD*0LoLPH*cjBuRgU_NHPG2s7l!g! zOvD|ewONTzWw9+K;+>JDTz@0^6f?vNrS&q#a;+g+j2<7z99z%KS);$JLw}QSoY7xx zEdts1i~vPh;jy|vV%5+c&2)~Pv|VUJzn_I4uIV_&&W~}}LA04GIrPkWzLlxxnR)4G zD643>xu{;{h)bm~?*e&u4Lx)-Qe}y>(A!c-R^zE9>M<#1XR*fC*+cQcKD&-pIH+rx zMP)P@&lr_3OtvZVr??nV7jq(nSsaXu_6NQ1$s0d=f{`*&w(BPRi1E4%eI-x3S`Czx zP%WniaLf*%rQ7w|7b65V@pv>ej*B!L2aGBCG~7OZlE9IK1NnZ zYwJp!ZfTXJujXJraa1f*$J#|(YKs|qj)hXI4%NOOpOb%Ps7EvP5;8 zg=9;%+3b^S1u%kjRUE#yMsHv{qww68=>Yr5;>ilZ($(tN3w$r9<}U zqSIpgn&|n25pRaC5pON+3E4BAvzUeNkgq>kLCcU-c#(-_^QxaFksZSHrR8X4R_+GY z@L8hnV0B}#<=7m~^Ezi)0@H~!UbY7(t)93_y0uT74u}qPMILulti6AVNg`B}l8vr*~=hk)7ggtN2D;fHo z&a=zfTmB<{nk4PTAs3J=p(ub-@?4r*1j6(m*+yJzOx#?6g0>1#+Md{TT0C=HAp^7& zqdE=Sbg8VC`qmRmEC)?(nB0)k=4JmNn2hgbt<1*;+hiMOTdawVx4*DWw#i_NBJ!58 zox~9z`x&|Lp!Gyav95Z7E$@=W=g4cCfzjGnrZQ*dX)2MfgM5vv@{F00F3a?czHVD2 zM6>m%_)>8dwEc`+3}#6uozM8?sF;`TMGJUogf5hiYyw1%%MUiyD&v85l+N_wF)PPR zm@oH-9Z)idN)DGH4Vq$F#Cnq&^&78Mrfny)H5B)Z2me_i(sntMsHM+$`U6rx>?JbAFC!vn_C=J z@@#y>EPoGenf)!$N5*PX6APtpyRy9nGj5~$$R=iPvX0VL%=0`({)#SFSV?WHj}a&N z@`3$DRa#NHU^mi;m{U=k*|Cb*`6ay>$UgJsrZ%Fg%*PEPU?s#zUu!`dYo(>_)o%fY`8K3(yA%R2d^QPlL!vzOune%O6ORvf7IZ?zuM5(bkFB)>6jPBiTfg zNFr=5cUV7Y8El~v??5G3U7O2R)_5l8{JYJu{T5-p^*l4OIGJXJ^^@z!Vl_RI!}MG@ zS!6kL;tG*35?OXtj%p7yBO)xIzXa8_6$=%;wVtNVb!TJo+<1)%_;vZ_~S<+E=lPwO*1!=%s`pfE5 z@AqZ{%O)>emv`-R zJxpebwFfeb`R|X*`W0~?o*A=1F3D~(N=tmn>`0`(d4}1Ns99Q`*7GZjbHe4kA}C zP**vb+4b-md260I%+e^N@O_DOJvS-A}er@?9t@vt=z+V?*@B zClbLrjk38m7YFnXd0-^jIN}WbM}w)~ z7uZa^XYtWWx{#SST*9mk zza*^HTojW(Ih&E%l69Gg)l~hTa=o2x5v!@57PXpWoou~>VmN{X#``5wsQc371Ffv@ zllNBY-!hOqzd>tK8Rbi7>W$XAmROLbpWEBiewT=j8PxfQvB+-Q30X>=(vv!z>YqEQ zmKrYMmaN5zH&+?2JyeIKxL^MjWlSZ9Ukq-^0V4PWTAt6GeUtjI34PPqP9T%~-kTgy z4G$Twt@0y%`a(hFIK#3K&fRux;`blUsRN7R} zm>+ZGO{#c*y^OEPGERPPxU3WPSJYt!&nIr3;XbdXG(bHu)=U6ndBKuwH}y?NOe50= zbIni8Ro3k)7ZLuthzxeR9Z?Rq!&N@fRUsg}7@LN~#q^Jfbqwk=$Qj$;9qp?4NwX;- zUuN8>5E|~y$HN?hhYoV~iF3G}u1f89^m;YSAV+Nfct?My>#bMc6rHgy-}rck+vOKg G>wf@xoP|9A literal 0 HcmV?d00001 diff --git a/data/qnt.pt b/data/qnt.pt deleted file mode 100755 index a99b73bda4ad2df4729ad5cc26610b2f16c8a5e9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20599 zcmai+3w+P@{>OjrOm0z>G>hbv>!u6kxH~eJAvL+oY*RzlW|UHvl3Q+(Ys!W$Ixa~R za$SW8-6TSEaTii4bhfkS)A9eFzyIIk(c`uE{r$PUKbP-XOmu;Qp-|DHq5t|tp9 zH?B^dx_3m@O6{LsJ25^zzSg6C`?YKu7JBf~zl{6|=_zRm57&wtn4a8kKv-N_Qm=u@ z{SxESlll)xNsAj48I~HVUOl!*j{gBUn(+@MrRK{s32T$m|DjHKys*^#vH5aT>ZBHk zEs%o=j}A;N*miNT^wdHfqC?U7qKme@Bj1SB!a1y*AEI-9EHWw~A*Elxq=fY3l>P(4 z+NUKZr6nacPfkb=OD+1hj;T4OVI5P8*NTZQQuO@GN!v4XyypHFbHxT<9SZ-i2G_3B 
zvT4Q3ntQ3;{*aWLZO*@_JP4J~pZCMW+}7m$ldgl_BJf+}Z|13mPGdJ5`BCWQ(CO$` zgPsJoK`#vX5Aem3@LvN5A-@*ff&CxQ6VRIpT?x5*XUI1a{vhb;(BfF^%OKx^-2O7Z z7~=hgp8Y6Rz|WV+?a#x|_RDqnDGvUOT@me@d}%Pl2xn=soC*wefoZy_=zjfcuD34>aG_(64~@llnn^S^R1DEqT}u z+dTz;j+fXF{{7h1Am7HIe9`tPmloyc!>@c%d62&O3i?64Agw>gNBRGUOG0`7Y?tvz z1o}g{{5|xy^k)3j!0rV7SPHrw^zYbhhIYK2|G{xxhTksOxqckaarm!-zfI6jqrV<{ z67&xA7QkPJ|8vkIpyfO7dL#ce;BPXt^CehEZ^GZFU<&%yp|a#7U6DL5z@LbIIdCa{ ztVb)McY>?2ABX)`@G9)h_ZRf71Hrm-iI2VgU_G#%;!NT?&emJ&S8wFD?>PBLuOOc` z#7hSC`!V(#!Pl|h3oTwHjeMQ$hra~67-$`iK)*P7=qC=l4#;h&5x&e9n*l0`jvy+3uI&yC2Vk{(!ihp`8z+{aOS0N$>~aId4Y6zXyM7p~Im^UqRn> z_HLlP^Pro9PvOUYJ%oSH1&`7HCE+_yzrxOaRk?l~uL$hrcOY&{{5cN#U54L#(QAyo zXx)>xo(B0%0(PEb)VI#+N3;%Fcjc?^xx{twDtR8JJ+^xm?HEWN??R`czaEr6j{U%Z zp69i>fjFMCj^lq4=>C?1{%H7Jp&!S8b@=T;{YRo_`&L4?AitK_Uya?609_LLQ}9P{ z8(0_Yj=gnd2L4OIFOJ?j#IFL{j+{%!zyDc3zsCMy_@~gn8(KfEYxf(+?G$>|v6~87 z-<{vui_^%{`FapPv0!WRdJo$1eh@opaUA+%iQgBxDs(5X2mFK3&M(K!arT`167jT~ zkG)d)}eR_;M{~pkn|2y(C#5sV!6X<7w?(cU(FGg=I z_MY43LT^LwB7VI$dv2E>g?^C#HvX-<`=MQz-LM-CI=&;&Gw(NuFTVzI(ettA@cZyr z8odnKVIA>)Hxj>&^M}yhGh8R1;ja*WZUO6~@46IS-_957Tut(+kNsZ!y8eHGe++bf z%qPxx^mig(j=$fLJASU8XykK{cZYVKxbCH0r+x76`gUAgC)VKw_}hTK{MX?p!goJ9 z3T^+{PvSV*=Q?~2Kh|N_t7u&J8~3jg*m=*XM*FP?)*s&kh7;HQ%==W3?>?cP=LPer zP8{p+S^VBeKGydO(5cAv`xN#;98SL>4lU$|j`$IK;cqe654#)CTZWzY&_d|FfV>a# zx3JT%_ww!NTVHL*xA4_pgukiqMg5A&=t=iOzaji^@^Srp&sv522jt&FpM{=4{`O-w zey>NbCVJMFHRxMUjzWJ9F2Vi?cHY+)KtG7yG5mjszV}!2*oA%!@+U#Za}0Dj(DT3J zV&3kfE%3LA_)inx^;R6(I=2^n$EhZ9oEMeR_ng=Td28(I1@KYy_rre%97aCkJoq)? z-;Uoba1Q*k(DCRGhW?c}?pugM-@*5NLVf!~x$id8w%_{f`^-D!p?nAa>w@V7w!h>X zSG!<;--5m4vzs`sAN9Wv@Lji#o9oc|+5o?+$jkW8;^%q#&v|eP{dUNmcO}q!1V8eB zgLZvd?@z+_ypw^wb$=7`O~k7Ny&Jjw*3H zB6@e>|0U>1>@dKP*K5b1)wK1iM(`?1KLR_SSRz!}i*qeAxMJ;k@XNpBea< zw!a_5t`T}ui1#JvyMy;t>$~&Fdi4YL&U@EQW#T#?)=Znb0b5VI{-+}&wc3-h>jmGa7;(Y^p&vGC2++!^Ne5%)dh-WwyK2a$*8%p1{< zg1?4%hrr(O-QS0z{~G9fo&EJKcFnMhK(8n94G(GykNfbe-1Y3f_%`;AYg6=Z$6mhoqRH?*$BOP-<|V%n`j6pPKi+@C z;amS5mo@aOe7_sbAz$x-?_oa>jKZ!B=(yEIevy3UBlmsoH~jie>3-|F+kt!q_VPVX zr@?ppo`=?-`ucTz)H8qQ)j9IZ29-N69EUfFbDI3@k0P|!_^z{p@QZTh z{Eqnl3p@9v2WG5Yre+H2m{ zm8InICO8d$2cUak{|Z|J)MErv2eTaQiP<}7u zGx6hhff4Az4(;HFf8zHAFao`^&>hh04Bvk8JbE2+&k5%L3;aFUeFxq~e!Y;N0vm%5 zqW2r|lfi84lAx=jp9245(EHHi_^E+@f9TcFjz??w{{Y?p&!G1Wer;D@^dg9V0KEkC z9LEpgF9LtYZZUL4XxH;X=;`2Q{2c;|fUU6K1ilM@IdpC4SnzB3^~rYyxCHrU$ZsI- zM(E1u%|_4t;vjn7BYi)2J^JoweA_3+beR;Pru$n_2+o4 zB#!;Q82S{bfAxJob$;pJ@e;k?k0rmWK))NN@Zodn@|Sf_1PPkKH`-vtCBS-;ADhi7 zhv#zVz5VXF-}xO!9*)1~GWXZ^#I@c!@AU6I$a9?aw=8~r2lAZjycvSu=HO!VtwXl^ zAad7hKKi8`{JX*5X}|q)8h>}7{}XzC&$|Kn{n%{)=c4Dki1oqwYrnjKU(ajLU}xR* zUDEc*KZ0HZ(0)8keD}BV$Sc8j{q-bnQ{;}vBHH7A?fQ5bzV-WB=rP#02IIg_@vpw; zvUka&4SHE%CisEz(M!aSddJ~!1AjvP8TO91?}J^DUxR!c=>Ag%KkoZ2pbLRV$fp6= z4txl|^24Fq2hjVe>#IEW>h&Za&x>C|doTB$WHff(bIs5F<~8hk5q}K%PlRtjIxoGq zIsYF(ZxD79h^N1Y;eQ22px=^szVDwzerdi#k3!$|e**s*@Lm6F@#p)E{j(GOMd;VY z@B7eyAoo0@T_k+Zt7o8HzrN3Vp8Grge5d^ieeDjA-*v=y{SAlqoz{15zr&8g&i!XI zwC9BB#N7bj_oqepjYe*}9G^z$zYelk4gh^qUZOB{&DZ=Nsqq=lGcdehgX{xpmd|1m~gm-4*D~z>oXgX!u{icU&hCU)p`l{m=c; zdiFByn*qAt`d;1;|1W}5!41UUPCVZkIwSXfWq-IoypMi5m_=OIk>`mm@H^qp`u7)c zeNXY-aT|QpL+*d_7b5T;?78Gg_!*%65yGG2wSau5fkVMQ*g4;9x99z1_-PL|#m;vD z*T)w83<1~B{!P%Hx2&h9k$eC82tDh?pU}HO*Ry$df!~OH+^@2s7m}~%l6vqR$5*j0 z19k;}1{K5gvUei)-L3<68$iEXSReMGZ@snNTmQ}TF!6c@{NGCaBlsVO z-eB}RcQ1z@4d3zd_i;QYyZ?1W|3m!y4%`kp3qL)6#QY}pP={nr|`cMj709aKMJ}9@-wu|3GdxH`^t zkpFEZmr>;!G1P&>peZ6fw>qi?-*e~v(YJMHa`pV#4g&*%$16TNq!YohOZ@;%>n4oBVsec!)*FYuh1 zO@23_e+&2>@j78|z0H93oIMMD_s!3+%Y^Ses2g^lAorfJ5q;d_F^O z3kk%6zCU?>z8O91x%<~j{Jl+l>*+=ONH@}_Gv7mX?hIU>K!9ErLzJsI_cRBoX 
[... end of the data/qnt.pt payload, the scripts/process_libritts.py change (2 lines), and the header of the vall_e/models/ar_nar.py hunk omitted ...]

     def _load_quants(path) -> Tensor:
-        if cfg.audio_backend == "dac":
-            qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
-            return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
-        return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
+        qnt = np.load(path, allow_pickle=True)[()]
+        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
 
-    qnt = _load_quants("./data/qnt")
+    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
     text_list = [
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index f0a019e..6a86ca3 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -944,8 +944,8 @@ class Base(nn.Module):
                     "logits": [],
                 }
 
-            info[name]["targets"].append( input )
-            info[name]["logits"].append( logit )
+            info[name]["targets"].append( input.contiguous() )
+            info[name]["logits"].append( logit.contiguous() )
 
         for name, batch in info.items():
             loss_factor = self.loss_factor(name)
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
new file mode 100644
index 0000000..606a839
--- /dev/null
+++ b/vall_e/models/experimental.py
@@ -0,0 +1,404 @@
+from ..config import cfg
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch import Tensor
+from torch.nn import CrossEntropyLoss
+
+import random
+import math
+
+from einops import rearrange
+from tqdm import trange
+
+AVAILABLE_ARCHES = []
+
+try:
+    from transformers import LlamaForCausalLM, LlamaConfig
+    AVAILABLE_ARCHES.append("llama")
+except Exception as e:
+    print("Error importing `llama` arch:", e)
+    pass
+
+try:
+    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
+    AVAILABLE_ARCHES.append("mamba")
+except Exception as e:
+    print("Error importing `mamba` arch:", e)
+    pass
+
+def _create_mask(l, device):
+    seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
+    stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
+    return (seq < stop).float() # (b t)
+
+def list_to_tensor(x_list: list[Tensor]):
+    l = list(map(len, x_list))
+    x = pad_sequence(x_list).t()
+
+    m = _create_mask(l, x_list[0].device)
+    m = m.to(x)
+    return x, m
+
+# fold into a typical LLM sequence (one embedding rather than split embeddings)
+def fold(
+    text_list = [],
+    proms_list = [],
+    resp_list = [],
+
+    ignore_index = None,
+
+    sep = 3,
+    stop = 3,
+
+    text_tokens = 256,
+    audio_tokens = 1024,
+    audio_rvq_levels = cfg.model.prom_levels
+):
+
+    device = text_list[0].device
+    batch_size = len(text_list)
+    input_ids = [ [] for _ in range(batch_size) ]
+
+    offset = 0
+
+    sep = torch.Tensor([ sep ])
+    stop = torch.Tensor([ stop ])
+
+    for i, text in enumerate(text_list):
+        seq = text.to("cpu", dtype=torch.int64)
+        input_ids[i].append( seq )
+        input_ids[i].append( sep )
+
+    offset = text_tokens
+    for i, prom in enumerate(proms_list):
+        if ignore_index is not None:
+            seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
+        else:
+            seq = prom.flatten().to("cpu", dtype=torch.int64)
+            for idx, token in enumerate( seq ):
+                token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+
+        input_ids[i].append( seq )
+        input_ids[i].append( sep )
+
+    offset = text_tokens + (audio_tokens * audio_rvq_levels)
+    for i, resp in enumerate(resp_list):
+        seq = resp.flatten().to("cpu", dtype=torch.int64)
+        for idx, token in enumerate( seq ):
+            token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+        input_ids[i].append( seq )
+        input_ids[i].append( stop )
+
+    for i, batch in enumerate(input_ids):
+        input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
+
+    return list_to_tensor(input_ids)
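
As a quick sanity check on the flattened layout fold() builds, the offset arithmetic can be reproduced standalone. The sketch below assumes the defaults used in this file (256 text tokens, 1024 codes per codebook) and 8 RVQ levels for cfg.model.prom_levels, matching the config earlier in this patch; it is an illustration, not part of the change. Since each prompt/response comes in as a [frames, levels] tensor and is flattened row-major, idx % levels in fold() recovers the RVQ level of each code; the sep/stop id of 3 lands inside the text range, which is why unfold() below filters it out before the range checks.

TEXT_TOKENS = 256      # phoneme vocabulary
AUDIO_TOKENS = 1024    # codes per RVQ codebook
RVQ_LEVELS = 8         # cfg.model.prom_levels in the config above

def prom_id(code, level):
    # prompt codes occupy [256, 256 + 1024*8)
    return TEXT_TOKENS + code + AUDIO_TOKENS * level

def resp_id(code, level):
    # response codes occupy [256 + 1024*8, 256 + 2*1024*8)
    return TEXT_TOKENS + AUDIO_TOKENS * RVQ_LEVELS + code + AUDIO_TOKENS * level

assert prom_id(0, 0) == 256
assert prom_id(1023, 7) == 256 + 1024 * 8 - 1      # 8447
assert resp_id(0, 0) == 256 + 1024 * 8             # 8448
assert resp_id(1023, 7) == 256 + 2 * 1024 * 8 - 1  # 16639, within the vocab_size of 16641 used below
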
+# unfold from one unified token ID space to separate token spaces
+def unfold(
+    input_ids,
+
+    sep = 3,
+    stop = 3,
+
+    text_tokens = 256,
+    audio_tokens = 1024,
+    audio_rvq_levels = cfg.model.prom_levels
+):
+    device = input_ids.device
+    batch_size = input_ids.shape[0]
+
+    text_list = [ [] for _ in range(batch_size) ]
+    prom_list = [ [] for _ in range(batch_size) ]
+    resp_list = [ [] for _ in range(batch_size) ]
+
+    for i, batch in enumerate( input_ids ):
+        for idx, token in enumerate( batch ):
+            id = token.item()
+            if id == sep or id == stop:
+                continue
+
+            if 0 <= id and id < text_tokens:
+                text_list[i].append( id )
+            elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
+                prom_list[i].append( (id - text_tokens) % audio_tokens )
+            elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
+                resp_list[i].append( (id - text_tokens) % audio_tokens )
+
+        prom_len = len(prom_list[i])
+        if prom_len % audio_rvq_levels == 0 and False:
+            prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
+        else:
+            bins = [ [] for _ in range(audio_rvq_levels) ]
+            for pos in range( prom_len ):
+                rvq = pos % audio_rvq_levels
+                bins[rvq].append( prom_list[i][pos] )
+            nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+            bins = bins[:nearest]
+            prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+
+
+        resp_len = len(resp_list[i])
+        if len(resp_list[i]) % audio_rvq_levels == 0 and False:
+            resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
+        else:
+            bins = [ [] for _ in range(audio_rvq_levels) ]
+            for pos in range( resp_len ):
+                rvq = pos % audio_rvq_levels
+                bins[rvq].append( resp_list[i][pos] )
+            nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+            bins = bins[:nearest]
+            resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+
+        text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
+
+    return dict(
+        text_list=text_list,
+        prom_list=prom_list,
+        resp_list=resp_list
+    )
+
+
+SELECTED_ARCH = cfg.model.arch_type
+if SELECTED_ARCH not in AVAILABLE_ARCHES:
+    raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
+
+if SELECTED_ARCH == "mamba":
+    LlmArchClass = MambaLMHeadModel
+elif SELECTED_ARCH == "llama":
+    LlmArchClass = LlamaForCausalLM
+else:
+    raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
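
The else-branch of unfold() regroups the recovered flat code stream back into a [frames, levels] tensor by binning positions modulo the RVQ level count. A small standalone example of that binning, again assuming 8 levels; illustration only, not part of the patch:

import torch

RVQ_LEVELS = 8

# a flat stream in interleave order: frame0-level0, frame0-level1, ..., frame0-level7, frame1-level0, ...
flat = list(range(24))  # three frames worth of codes

bins = [ [] for _ in range(RVQ_LEVELS) ]
for pos, code in enumerate(flat):
    bins[pos % RVQ_LEVELS].append(code)

codes = torch.tensor(bins).t()  # [frames, levels]
print(codes.shape)              # torch.Size([3, 8])
print(codes[1])                 # frame 1 -> codes 8..15
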
+class Model(LlmArchClass):
+    def __init__(
+        self,
+        d_model=1024,
+        n_layers=12,
+        n_heads=16,
+        p_dropout=0.1,
+
+        attention_backend=None,
+        activation_checkpointing=True,
+    ):
+
+        if SELECTED_ARCH == "llama":
+            super().__init__(config=LlamaConfig(
+                vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+                hidden_size=d_model,
+                max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.prom_levels * 60, # max-length of 60 seconds
+                intermediate_size=d_model*4,
+                num_hidden_layers=n_layers,
+                num_attention_heads=n_heads,
+                attention_dropout=p_dropout,
+                num_key_value_heads=n_heads,
+                sliding_window=cfg.dataset.frames_per_second * cfg.model.prom_levels * 12,
+                hidden_act="gelu",
+                is_encoder_decoder=False,
+                is_decoder=True,
+                attn_implementation=attention_backend,
+            ))
+
+            if activation_checkpointing:
+                self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+                    use_reentrant=False
+                ))
+        elif SELECTED_ARCH == "mamba":
+            super().__init__(config=MambaConfig(
+                vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+                d_model=d_model,
+                n_layer=n_layers*2,
+                #ssm_cfg={"layer": "Mamba2"},
+            ))
+
+
+    def forward(
+        self,
+        *args,
+        **kwargs,
+    ):
+        output = super().forward(*args, **kwargs)
+
+        if SELECTED_ARCH == "llama":
+            if output.loss is not None:
+                self.loss = dict(
+                    nll = output.loss,
+                )
+        elif SELECTED_ARCH == "mamba":
+            if "labels" in kwargs:
+                logits = output.logits
+                labels = kwargs.pop("labels")
+
+                # Shift so that tokens < n predict n
+                shift_logits = logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                # Flatten the tokens
+                loss_fct = CrossEntropyLoss()
+                shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+                shift_labels = shift_labels.view(-1)
+                # Enable model parallelism
+                shift_labels = shift_labels.to(shift_logits.device)
+                loss = loss_fct(shift_logits, shift_labels)
+
+                self.loss = dict(
+                    nll = loss,
+                )
+
+        return output
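
For reference, the arithmetic behind the hard-coded sizes in the LlamaConfig above, assuming 8 RVQ levels and the roughly 75 frames/second that the 24 kHz EnCodec-style backend reports through cfg.dataset.frames_per_second (the 44 kHz DAC backend has a higher frame rate, so these numbers would scale accordingly); the frame rate here is my assumption, not something the patch states:

# unified vocabulary: text + prompt codes + response codes + 1
vocab_size = 256 + (1024 * 8) + (1024 * 8) + 1  # 16_641

# one audio frame becomes one token per RVQ level once flattened
frames_per_second = 75                          # assumed EnCodec 24 kHz rate
max_positions = frames_per_second * 8 * 60      # 36_000 positions, the "60 seconds" comment above
sliding_window = frames_per_second * 8 * 12     # 7_200 positions, about 12 seconds of audio

print(vocab_size, max_positions, sliding_window)
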
+def example_usage():
+    cfg.trainer.backend = "local"
+    cfg.hyperparameters.gradient_accumulation_steps = 1
+    if cfg.audio_backend == "dac":
+        cfg.sample_rate = 44_000
+
+    from functools import partial
+    from einops import repeat
+    from tqdm import tqdm
+
+    from ..emb.qnt import decode_to_file, unload_model
+    from ..engines import Engine
+    from ..utils import wrapper as ml
+
+    import numpy as np
+    import re
+
+    device = "cuda"
+
+    def tokenize(content):
+        return torch.tensor( cfg.tokenizer.encode(content) )
+
+    def _load_quants(path) -> Tensor:
+        qnt = np.load(path, allow_pickle=True)[()]
+        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+
+    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+
+
+    text_list = [
+        tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
+        #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
+    ]
+    proms_list = [
+        qnt[:cfg.dataset.frames_per_second, :].to(device),
+        #qnt[:cfg.dataset.frames_per_second, :].to(device),
+    ]
+    resps_list = [
+        qnt[:, :].to(device),
+        #qnt[cfg.dataset.frames_per_second:, :].to(device),
+    ]
+
+    text_list = text_list[:1]
+    proms_list = proms_list[:1]
+    resps_list = resps_list[:1]
+
+    input_ids, attention_mask = fold(text_list, proms_list, resps_list)
+    target_ids, target_attention_mask = fold(text_list, proms_list, resps_list, ignore_index=-100)
+    prefix_input_ids, prefix_attention_mask = fold(text_list, proms_list)
+
+    kwargs = {}
+    model = Model(**kwargs).to(device)
+    steps = 50
+
+    optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
+    scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
+    learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+
+    if cfg.optimizations.dadaptation:
+        # do not combine the two
+        if scheduler == "schedulefree":
+            scheduler = ""
+
+        learning_rate = 1.0
+
+    if optimizer == "prodigy":
+        if learning_rate is None:
+            learning_rate = 1.0
+
+        optimizer = ml.Prodigy
+    elif optimizer == "adagrad":
+        if learning_rate is None:
+            learning_rate = 1.0e-2
+
+        optimizer = ml.Adagrad
+    elif optimizer == "adamw":
+        if learning_rate is None:
+            learning_rate = 1.0e-4
+
+        optimizer = ml.AdamW
+    elif optimizer == "sdg":
+        if learning_rate is None:
+            learning_rate = 1.0e-4
+
+        optimizer = ml.SGD
+    else:
+        raise ValueError(f"Unrecognized optimizer: {optimizer}")
+
+    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+
+    optimizer = optimizer(model.parameters(), lr=learning_rate)
+
+    if scheduler == "schedulefree":
+        if isinstance(optimizer, ml.AdamW):
+            scheduler = ml.schedulefree.AdamWScheduleFree
+        elif isinstance(optimizer, ml.SGD):
+            scheduler = ml.schedulefree.SGDScheduleFree
+        else:
+            scheduler = None
+
+        if scheduler is not None:
+            print("Scheduler:", scheduler)
+            optimizer = scheduler( model.parameters(), lr = learning_rate )
+
+    if cfg.optimizations.replace and cfg.optimizations.linear:
+        model = ml.replace_linear( model )
+
+    if cfg.optimizations.replace and cfg.optimizations.embedding:
+        model = ml.replace_embedding( model )
+
+    engine = Engine(model=model, optimizer=optimizer)
+
+    torch.save( {
+        'module': model.state_dict()
+    }, f"./data/{SELECTED_ARCH}.pth" )
+
+    print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+
+    @torch.inference_mode()
+    def sample( name, steps=cfg.model.prom_levels*cfg.dataset.frames_per_second*60 ):
+        engine.eval()
+        if SELECTED_ARCH == "mamba":
+            output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
+        else:
+            output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
+
+        unfolded = unfold( output )
+        for i, batch in enumerate(unfolded["resp_list"]):
+            _ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+
+        unload_model()
+
+    def train():
+        engine.train()
+        t = trange(steps)
+        for i in t:
+            stats = {"step": i}
+            if SELECTED_ARCH == "mamba":
+                stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
+            else:
+                stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
+            stats |= {"grad_norm": engine.get_global_grad_norm()}
+
+            tqdm.write(f"{stats}")
+
+        torch.save( {
+            'module': model.state_dict()
+        }, f"./data/{SELECTED_ARCH}.pth" )
+
+    #sample("init", 5)
+    train()
+    sample("final")
+
+if __name__ == "__main__":
+    example_usage()
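
example_usage() checkpoints the weights with torch.save({'module': model.state_dict()}, ...). If you want to poke at such a run outside of the Engine wrapper, reloading would look roughly like this; the path and the assumption that the active config matches the one used for training (same arch_type and prom_levels) are mine, not something the patch spells out:

import torch
from vall_e.models.experimental import Model  # the file added above

state = torch.load("./data/llama.pth", map_location="cpu")  # written by example_usage()
model = Model()                        # rebuilds the same arch from the active cfg
model.load_state_dict(state["module"])
model.eval()
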
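
To see the fold-train-unfold idea without any of the vall_e plumbing, here is a self-contained toy version: random "phonemes" and codes, a simplified stand-in for fold() with only 2 RVQ levels, and a tiny LlamaForCausalLM from transformers overfit on one sequence. Everything in it (constants, shapes, the fold_one helper) is made up for illustration and is not the patch's code path:

import torch
from transformers import LlamaConfig, LlamaForCausalLM

TEXT, CODES, LEVELS, SEP = 256, 1024, 2, 3            # tiny: 2 RVQ levels instead of 8
VOCAB = TEXT + 2 * CODES * LEVELS + 1

def fold_one(text, prom, resp):
    # simplified stand-in for fold(): text ; prompt codes ; response codes -> one id sequence
    prom_ids = TEXT + prom.flatten() + CODES * (torch.arange(prom.numel()) % LEVELS)
    resp_ids = TEXT + CODES * LEVELS + resp.flatten() + CODES * (torch.arange(resp.numel()) % LEVELS)
    sep = torch.tensor([ SEP ])
    return torch.cat([ text, sep, prom_ids, sep, resp_ids, sep ])

torch.manual_seed(0)
text = torch.randint(4, TEXT, (12,))                  # fake phoneme ids
prom = torch.randint(0, CODES, (8, LEVELS))           # fake prompt codes, [frames, levels]
resp = torch.randint(0, CODES, (16, LEVELS))          # fake response codes

input_ids = fold_one(text, prom, resp).unsqueeze(0)   # [1, seq_len]

model = LlamaForCausalLM(LlamaConfig(
    vocab_size=VOCAB, hidden_size=128, intermediate_size=512,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
))

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(20):                                # overfit the single sequence
    out = model(input_ids=input_ids, labels=input_ids)
    out.loss.backward()
    optim.step()
    optim.zero_grad()

print("final loss:", out.loss.item())
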