From 071fb97777ab732c16e72317fefc35870f51469d Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 21 Apr 2024 14:49:18 -0500 Subject: [PATCH] dataset preparation script updates, caved and am using HF tokenizer now --- data/config.yaml | 137 ++++++++++-------- data/qnt.dac.pt | Bin 80456 -> 0 bytes scripts/prepare_librilight.py | 4 +- ...old_dataaset.py => process_old_dataset.py} | 57 +++++--- scripts/train_tokenizer.py | 57 ++++++++ vall_e/config.py | 20 ++- vall_e/data.py | 32 ++-- vall_e/inference.py | 11 +- 8 files changed, 211 insertions(+), 107 deletions(-) delete mode 100644 data/qnt.dac.pt rename scripts/{process_old_dataaset.py => process_old_dataset.py} (73%) create mode 100644 scripts/train_tokenizer.py diff --git a/data/config.yaml b/data/config.yaml index 82ccb35..5f106a3 100755 --- a/data/config.yaml +++ b/data/config.yaml @@ -1,51 +1,23 @@ -dataset: - training: [] - validation: [] - noise: [] - - speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" - - use_hdf5: True - use_metadata: True - hdf5_flag: r - validate: True - - workers: 2 - cache: True - - phones_range: [4, 512] - duration_range: [1.0, 32.0] - - random_utterance: 1.0 - max_prompts: 3 - prompt_duration: 6.0 - - sample_type: speaker - - tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"] - models: - _prom_levels: 8 - _max_levels: 8 - - _models: - - name: "ar+nar" - size: "full" - resp_levels: 8 - prom_levels: 8 - tasks: 8 - arch_type: "retnet" - training: True - version: 3 +- name: "ar+nar" + size: "full" + resp_levels: 8 + prom_levels: 8 + tasks: 8 + langs: 2 + tones: 1 + arch_type: "retnet" + training: True + version: 3 hyperparameters: - batch_size: 8 - gradient_accumulation_steps: 32 - gradient_clipping: 100 + batch_size: 4 + gradient_accumulation_steps: 4 + gradient_clipping: 10 - optimizer: Prodigy + optimizer: Adagrad torch_optimizer: True - learning_rate: 0.0625 + learning_rate: 1.0e-2 scheduler_type: "" #scheduler_type: OneCycle @@ -67,22 +39,24 @@ hyperparameters: # decay_mom_rate: 0.0 evaluation: - batch_size: 16 - frequency: 250 - size: 16 + batch_size: 8 + frequency: 10000 + size: 8 - steps: 450 + steps: 500 ar_temperature: 0.95 nar_temperature: 0.25 load_disabled_engines: True trainer: + no_logger: True + iterations: 1_000_000 save_tag: step save_on_oom: True save_on_quit: True - save_frequency: 100 + save_frequency: 250 export_on_save: True keep_last_checkpoints: 4 @@ -91,33 +65,82 @@ trainer: load_disabled_engines: False #load_state_dict: True - #strict_loading: False + strict_loading: False #load_tag: "9500" #load_states: False #restart_step_count: True gc_mode: None # "global_step" - weight_dtype: bfloat16 + weight_dtype: float32 amp: False backend: deepspeed deepspeed: + inferencing: True zero_optimization_level: 0 - use_compression_training: True + use_compression_training: False activation_checkpointing: True + load_webui: True + inference: - use_vocos: True + backend: deepspeed + audio_backend: "dac" normalize: False - weight_dtype: bfloat16 + weight_dtype: float32 amp: False bitsandbytes: enabled: False - injects: True - linear: True - embedding: True - \ No newline at end of file + + injects: False + replace: False + + linear: False + embedding: False + + bitnet: False + +fp8: + enabled: False + backend: "te" + +experimental: True + +dataset: + speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'" + speaker_group_getter: "lambda p: f'{p.parts[-3]}'" + speaker_languages: + ja: [] + + use_hdf5: True + use_metadata: True + hdf5_flag: r + validate: True + + workers: 8 + cache: True + + #phones_range: [4, 512] + #duration_range: [1.0, 32.0] + + phones_range: [0, 512] + duration_range: [0.0, 64.0] + + random_utterance: 1.0 + max_prompts: 3 + prompt_duration: 6.0 + + max_resps: 1 + p_resp_append: 0.25 + + sample_type: speaker + + tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"] + + training: [] + validation: [] + noise: [] diff --git a/data/qnt.dac.pt b/data/qnt.dac.pt deleted file mode 100644 index 80b89fdec133a4fe24ee8f783d9a934968875c35..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 80456 zcmb821&~+Qwx=6+cM0wg+=Dg_f#4c~yF=sd?m>bE4;EYk1b0GkcXxsleBkE$Qg?dh z)zqXa@0aEKoU`}ZYps9ZdKJrs3LOv-K77Fc@lWJ{xB(s724!j7AgDp+_N`i%42&62 z`~Uc-U&uy5?K(DUp1EG9b~oW=%ST_{SvTx9!?CtMU&oVuz45Lwr`r+96cUP@l0;`_3Igm*1Wys6&`a=ts{cH2!|p*%+I?|XGBDZy<%$3I zSJ1jeP*Ce8UxExScq2jB&jG&t`#%SX7!dgHTg{TSWMHBHXN3PfUU2Gi$j{+E|Fh&D z2q+9)=gZst{LketZ}am%8Nfo|JFxUW;($lcaj^dghJ)V=y&mvaz~4fgo#1KYW5I0b zHAX%YoR7cd;2ZD<^n0LJ47>s61AV^K#MuXS0f&O$;4c#NHsU z&*zfg1^GSv6+mw$X#73+{T=>j>}p{54g6oxTZlXga`)*#TCRFaH7l_4g6l=k~j9VLk3Y_sJFXAAuRbQeaHtyN~7n z0bhSdS>FNp=b&?eFF?<;=iwRp^24Dw5&i+zZJy{FjDG_8LG1P-&(HeZcV*!J=hpcw zTY{W4 z!B^lR?7f!~LkD8#IiJM324ZKvOa{LU{4QWa_`l=NedqHB!1uXtVK*7O8PFM_-5&+u z>rdMCq(W1#0fG68xfddBGmzcy&TXpWzX=!Zb=`F8#8tNO^T zGt9H*;mPC|+xqtDANX@Y_mS&*k6vW3JMrcauMYfe*bhbi0lbAg z2J%^8Htc60kBEK_&TJM-9=-j(?xqeKZw*3*ub^^AKkzbXN4v;4c|;d-RM`34SrKKKA9I_kmG;4&=MQ zfyjFj=PY{m7yIG&g>N2Q4BtA^e!)Cv9rVaOM}Cs{Z-~1LzWE4tzy|!Ahj(M=d$rCw0)Hy#J!_t}PJM?w zGJYF@=4b1G2COG0a@XTIstscS{x+ev3OnaAiO_2Yy5FZ^_W|6Gy?u!Hv*>sC{n?*& zz|K5oJ#w6NR1t}9zmpleATT<5=AQwq%YEWG_Pd8d9v-{M_^$xw#jh9^x;t{ei}y+c z)=~wc@8}febBo{Jo9;I{AjN{J$(BW_y1b#4-s!P zdW+%PCmw_LzFZ1D5B;y8eV_J8qgm%x{7*sOdU2M1;h#d^d)E8O^Rx#)#v6hDCHSpC z>&{Ea3xk!Bzkp6geDhyO4F1e^K#&5xWUs z2=Fy_D?$6XRmkgN@4Yq-`T-aSy1#$K?gZ=b-a3Z=jmYCdEB73FpH;yAAovP9`_6mV zKPLW0upN5ltBUY9qrVN>{-+kS^E1z77W}v_=O;zc4`f~L<3`Y3iBkx@Dqt`CS%-9k zP7eAWI^$;`cAksO(B{1&>VdhDmmwy!fqkxcODGw_pz>-g5Md~w+G+gHxhUj zd-u2ZR6OhpfYHF@=(|svKwJNtpU=V12>O0T6E_EV9>0&#^ZYkN?p(_Flh8AN=H|Jq z%bIEj{~h}8;17py-Y<=OJN${D`SvDq>#3*6H-I(pYhPqP;=1g^?H`6>XFX}&2!dak z_hJ4qkJ%@fe|Dn(3Uq&Y@7;tynDwQ{uX%VZwDss2o^vw%_R!9O%qPxGtgBlP$M0H< zc*c2(zjdJZ>v{C7>&;_((5pp!^FR^!);Dd5>-YWy?Y(py|53n^;28W=1Fc)z5N|Dd zVew~vynudt&~xCuvK>3`!LHa_-?f4EdAj1q`qz473Vza~Z{GKLwX4W_51`iU>#=h#2Dzof&j_w{W2j$%Dk zkw?SM-@zNfcRrF2{S)w;fZ4%C*jq>YT>apG4fX}2;m>n*1NvwDdET$XHxF4Kpi$7z zvA2Gj4&Qs;Ix_=$8Nqa*`*|jI4?*+Ybo5Sx&Y99crv&|dvLNe?fFJu~>%2MWn+L3O zJwFL}Zs0$Eh4jV5IhlRB`=v4QtjDHc=<~K?9o~b!zqs(TqHi5z9~OgkRt8;hNu@06Q@Ev~4E7sdTqIVy;`_H<;d(U-+z+qDC=3w_Lakn6UhWrNdBgmg2_a3Sa z{|5YX=sRb#KWGf!JZ}E4K%6nyd;h1xuX%1D^i=eJ$6pcn?g!^@Y0)bTCPsf7`um{k zgQL(}1|0?3{8ms3Havce9$w%$M`e<$AbS2X#ZdxvL3(o z-RALy@YR3AI@;p@9<=AddUQE<)>Z!AY`*)Ob@apj9q4>!H}>W;`yTUR0P>62^}wI~ z$^_PD9aIwf0(Q>f+)t(O+rs{AC$iZ`w!Un1P@~ug*Zi_J*U=deedl{qh!X|5_j4coS?|Aq?uEVUss#THeEU=TmWc2-6Xyu}`-tZ`&4he2`i;=@ zyJSZ0J*od`_&E!D-yA}}JAM}9Hv)8dXzKy%g`LRn;;$Tj?EhQCe}Y~z{5Wrahk=vi1$8QhtPvo9!^KD($<9XPHUQ5IDt34fJA@4NfxHx1^;ykS233b}Qw`??1HUgGB# zw0UU<^akt;f&MOLKWV+?zVy93N8kIXHMIFE75?q3eDCpie(PqxpYPwfjr+>?5FfjQ zJWp@Xb7Vc{clO*JMlU?}Em)`L_6dCFSTCWye@pWm-q+qMkKp&m-=EOUh}RlAGjuWh zI-ltSJp{RZT~zq5!94h}&Z~sJ`8^0x5fg674(&^^&N@A$md zTizdKSkDRYDE@2XX9oN>U@`d4z3ro3U}xU7Z?_-IfjlMaat=})f7TUkiE{=13H19x z2SS@yZ@~WswBL0;njOCTA}sdKk((2*JL|Jf+m7C2`1!!u=&wTV9Nh2g`)`5XV$k1( zoEz1HpA$ck@&6P2TJXoRe)pm8YaINm;2!Me;>Y@=AAbBU-Yav_?~Hy1@FZ9Zzt(k! z;G0J~V&4V5htS>^`Jh)}=Q$_@ogYjA8ZRDn4g9`_4hQ;u79jr~zV)VSez&3U z`{QpG^2pdZAF(dqi~c0^mqAZO{}fmSv`@GY?D@tjD0{%9}{==@E-b>DU5r&x#grgLfg=)Z{b z6M7S%ml4B?zqy~@A#a2JCj3kXJy+7!$EA=LBfje?3BM-(o**wwoL9&vBR`AW z-znV3zro*)UVqTvo2An~D%{Tb>{Exzq>!^l*@4x8C8^QO!^E`VWuEno?&>hzC2|S3r zuKqypi8{#b$I=teduAtc_k;edgRM_ieN34$yqvyFWk1iojdHq0J=VyVx5PlB6&mYhaSciR+-zhA1&ff!|-EY=I z_A^89>-llc&=LRJz$o~M3hn##em;-fx->6(_7UFS_J`)Xb3Bjd)xO6$qvvZ0>%0v* zzu(~d0lmlVFJ{Ac?r;-24C|GS2R}0F&I&ruc#J$d@$IiY&uh@P@3HRoIt- z?>t7lPaNkN4WaEfyx(qOpAdf^iQfbLd!Tbu^Jon8JZJX7{;s|iyXpA<_xA$q-V@LJ z!a0cVuO#yO!Qv#vpZ%)8S2>sVJKM*f#ZMOe{Dyr;*5^LD4c!mp36GQ zyy&^hjlT6$5O(H`wAiO%eMO)>$JS}h;hXQQe`{iAJ>3<(^58shK5?x>&%!SNe+ui| z4Bt7MbA9W}Q0Tux?<4dL*0CEu{+<$uUK`}*txo80i4AN@PY}zq3MHFI6DE=fd|q0KR=g7~&`ogM23b-a*$R zuFpFMx(0IVUDyG>Kj&QTNB?&Li%`6Xz3*+ZGM_F)e;N9l(f9Ykobc`c zAEIv_h==}b_-|NGEBL=)KNGq?@}JSK3~oWsbL#xle&rE%-aqoKH^<>WG^pG>YoBR< z@;Cl`|5u2A3%={92K^)Jyb0bxe-iP$H@}76iQK;*`UX4e1?$C&@K2)e{pg&_dbvM- zE24Lhb@hjSjNEyR^^Ny~^WS~wPe$K5DL(w_p!w`Q^d#b~0T*Du9J(&(oGdEx-$CnL z=R+IO_x<02_Pn}3-3R8+h`g^r)|DK)M$pdp><13R&kvqM-~BR+^?gLX4*#*>M}lvi zGY$LdpmUW$*m;ln`-XE%_m9sp8h!KsPU55^ZfE2Jp{)zn8KTBjmUV@P1tfzbVgQUAYBC-)n8wmj}!QTF(x`&m8>R0^8%~Yf$@W=xaY8f6GDl zXFL4#2Y<)USC)S$orcEWBoc4eV`&lB;}19@`n zUqD;8X2VYTIP|O+vLHW*opplscmep{^WL{D;QPIeYyIZEkcf3QN54MtUlVsKwCA%i zwC~6M$nV?*c_r-cfv)c_X!BcD;(rMCeE#nK5dEXbt;+)8+b^#ouJ^O|igSYq$j!$? z@HYlOFQKnOJJ-wy{Q*1yy8o=h?AOBJ$2nze=*rkRm);KjJMyNy2hV}OuY`kd-FgZ; zAfPS(kr}<;c~0j8<}Le^2gKU~nveXw>kRfqkk1Fr=g*1L1B{E@^Su?m_pSAzb71pV zPu6Yz8He4U#7~dkYv2pyOR!rD`h3&S^Pcz#TJ6^W2-CJpZoK-=E!=!?2r*-wVWbzU=wE zh}?SH_wPQKfgkIvwCGO;+as?8T89P_Z$Ixvzi;3lK;OSpao^{F@4je{y>*mzrFqlu z>O9POnfuk>MV{dIF8&HbCxmVX9S?MUbD+obyhqSCPqu|`AG-;C^JNj_v*AC-J`r>h zr(p`??Jzx^8(lF z_uGcO-_<(K@8TS5H0#LAbEOBZ|J{%FAyavd$LQ6E{sBKD;a7#;3B4Y@ebC>6kI^$f zHiw?h`t8$7z_;GXk9{rt7C=9|h&&Z^5PCV`dycOm_j`G78>b2KP{cogpMvl!L#rPW z|54D-hMvFwB}5((w0w(U(mO%fCzY@?Xuv-Flfd4%>4ER#-l$2Yz#W5_uRX$Iw3!d zzZUq(1-~Tdy4~k{;eSW``q*EAt_=MUe`7)CI=!&>9;%N#HufW+z2}QUTNgS9^`7u| zu~@`C2~NQ-0)C^xci#oUp8TJn%T)NWUU`AsymS!x5&T(C*q(u-b1%QFCuhK=t0n%@aNpX^Y=A= z_9HKc{s`=oBlqthzD4dnc5XTV`@QJbz#jDg>&sK-nb0{gn4bGkR^AAoPY?fqhX-UoSl^sO_3h+l#@p7%uPUq|l#I|_Xix#y`n zeDlsG_^j~UlqN~(C$m) zop4`(F+lTlZ|GLohh=@~;1|OGEzo<^^STed`LGvu&M(^G_bqyRkjIDb{?1BV^M6b9 z7X^D?(XeZc+&=sWev2WW4}BLsY2T-F5P!e%99u8u!LZd9fOH&VkK$U-3R_qVK(t4*k-gzsr5( zc@tu1zVSOGLf>=eedm0jFn0c)`6vGDgYK{n=NQS6cZdHIm=gYK=p^_J2Y(K<>$mP4 zh1@w|U+nE`?Tejr{VF8;IUkHU z^IY>mTjaH{tIxVKqVIR}e0&Grd;SmXcH^ft@|y6SU-_ zeD5vKlXe&pPxLde1=lBk}VRemUr4tj~U69e!)T_rA*n z?RqQV$9jJqa_v7sdvA`0E{gx8(4+C=?;h68-s3}9=S}QtVrTvdhuk{Ueq|SWZ9wz- zC-mK)*2#ldhxKO~XzQhA*yY8a@6EZ$8{~Od=P%elgLa?Jg0`=BKG7WgI>_x07eae) zc#cZLFAByWzW3>L_-Wuzf%d+R4Bx!l75WP3-^E^LT@i!H|HPm5VhU*Uw(;C|=8>ZK z*@fSR*gb&uK0JqBUEX(Y=t9uLuqy`c#NPfOJACucv|w@US3ED~*|PYzZuPvmAFW67 zqvyTWhxK`G-?5&=pwC+XJ?pfn$g_hlK>I7}y|0PakLU2aKEt1NLL=-yfcCl0!QWzk z7rTV$U4n0Zw60l0+{@_u9rIx4-)Fn8?Ss4FPx?OkKjJSJ@}cmZPk1gbAh+*LjGgmZ z`*r6)_N|wF4%Tbk;O~3xpAp3McSGkG=8v7&uOMy?a4G0Mu7V%?<#@A;-eJMaYlM-sOR>&OdQk0wQ)1YC^VI>mfqzMYGnb^UeXoy0B?2O-T$gQUmKu5*j0qn}7Z++q4ZS+Te7rURK%YfEJ(V@RZ&-iVK zWBwWd?f&*0XT$y->nHzIVR54Z__c=TMSd9fLKJ;4j; zoyE>Mqw`1WgW8^nFx9&pJcDjj%HhyMH|gw^?Ul>|3F4|LFU%53tTjkDh;T zTAAkvi@!mj`K<$Te-FDzyjNgK{4IcvkG^wBpK~gHql4$rH*YxCod(}J(EMb6_V2*W zlZA=rd5?{~{b3R0&54@_OpTu3-Fy=d`5B(eym$>i>#=j5nI1Y7*c5y7-S7CZ9)65o zMf9u(VnY9nzV&Ko_>sVH$lXsqhx>8{>wFA4ukrn_B>r{$G{LWR&tdr9w|>W;(MyB8 zJ^sDV%pcw}=CcskSfxC(0+|fDT9qhw{i}1f6+I?FIKZo%%33&wQ^3dzC z?*{fjo&>q)#r}8+emxiV)8*k;!%rCOXW?%&^bO>r&~FNVFnYzn#qfKe|0if)WP)Tbj6qkz4=GB~Cxkyl4Kn ziQNR`&V5hfrzU*An|^#x*10XwbAP>sKF<2AyYs@YL)?4N?TObIx<7V)ckAni*mXuf z67&M>JXgMV>r3}#M)drS&xn%@f7|ib94rVw4tCBR9$;tvatZxy=rzVJ7JBy2*29O; zO9Hk)-~2EcyYOHja`RUf;>879fZF^13gh3t#{K4fVW!d$5~~-$D4zg#GXEA3>)?-+fjH`9b79XIyCOAm@sD z@U8c(r|w}G4y=y9;@H&%=fQV>*1=yD-`jmM47*L}HHE&8pIFdw ziGL8ZE{Tkub%1@Ub1CyjYxHV^K38?(T9^1cKoD{B5pOJZ)&cH+>xY!+e~o@5^xdb+ zkh{)x$n96^vkv>2B|eYmXXDTFKN1=U$cEoX`0*a}cho=NKSyr}bSWI<&>zob{`-W!b-Z&< z^YeA&m<7DS&n5VtgK)(2caAIQtwpatwD+pdZJ(|m_oMH@`JDBp`7Z!})}v*h(}Ja0 zu=Usp=tFw5T^sQ^S$Hnt&7qk_wNp@yUb$+(06X_`_0CBovV&xT|GeO zr1oQecl-F{_;sFhk#!WsZ#evzHv-YmgghzqPOu^NXR+G~oe+J`vF~>QdRMS-1|1(e z>k0e62{-Y;$%fT7g`@8yC{C|sm8EE&_OzeJxUmrRoI2wO1z+vdQ zPu4*@Pb@(k`;Jt|Kf*U}I}gbK-}>Mdbbs&&>#(l1p7lHJL2mzKJ`Yda_n`DV=n-HE zurYqj3;T)x4g9a6{XOIw^da?lKY7r%9&(>r2iiZJL%%inl6C9=`{37n?>wOn{;MJn zgf0l}y|flTp3ecqF~7R6ZlL!Hd*>)q@p~M*N6`MR5CwjI?8f7#Gqm$=|DLK9{3z&o zUzCI{jK5FlwI<#*_&(1=^!*;*@8Qt%9({q{a{O8Eo6n~s{}J2(j>C_01MinC@aN!f ziR)sW&UZs1_qollTj6^?t+)L>$@4xQf95muRz>X9n?ro>`-jNA->vu7;;$C+{yeAk zv-#FMX?=GSJ^Ssu*jtC#uX+B>ht|#3>7LK?ti$`sdCx5j_qCWxp8S6|A-~DWU zD}}$J===MO`_J=kU6BHRS+UECU+a_C$S32k270%_lgJ}N*T6mw`Wr#(1^Wx4}Lv=p10@t+kss)^n$?o;E$mBsStEc?1zZ>u^vx=-YM1*3M`FY1!#ZQSODKV znFzn$CmoRMHw?7ze>U`auo^fVTtqz2gZkZwW1SckdNTTVpxyVOd9KaKUGFaRA|bB_ z?KvI*ZQpBN&Bpr5;O`CeM*NsZ^ym9*fW}Q9p^9yMGZ~?sp_)CG@y7wM@=W*sC?@9A>SK@ddr^D|^{5cPKhTb~-wgrCy z-S75+<|ogS`}YKX&2xU&*!cN__`A@v9`QRCMsFyHTEG+5*AaeB1Ph^0gU{%T|IeUlMA^SpE3 zYUo{KJ%!Ob41EH*^^^68byr;EzoBnkQG@mR9arPm`Ha6W`uo%*>=WYeckF$SbKyq@ z&4bpjQ?v)I8|{NH;%^7*NQ+*1(0Q8s;y!V__dUP%dCrTH;5T?Zu@C#D;1%@BfE&Q4 z$h}`*LR)7!KiCC-8Fu5L^YT2gq0KAiX@74>z;k+zr=for{t)~OgYS3uehDCMTl757 z?x*kY|K3OI;oI*8;->)ovFJ|)%~P)DIDGTm8tjLH`|)Fa zXuV;->-@?6;(g&g*cST*#97SqShvL@PH*_d(Dz&w!)_wzy_FF=>*PS>&ZTxhk3(-C zeyuw%K|3dIiCz`Zx;hSe-WR!9=PCRTh4#M91pgg&o(JdftKoM7&Fjsv+eCcpQu~-k z==TT5fH%>v%sR|pnc;87ulXT1cH7Y(jNT{onjs&DyalxPzTfv-_yy7PzHmSJp6u^E zmus;5n|S6m?-%EY<^k)od&CLPdK-cIZI0Z&{43&E_xhb4VV4_ud*}wN$9_5_e4pF8 zsvi6k#7hdAC!FW3f$w+jPP~!Wn}`1(ZWH+S$^Ly*dHl3Pz5*-<|1MY;Kh6o=&&82D z-)aZl1oWOxjo;C%Gco=WL7U&a567Ty9dQGHv$6BN`TMYQXXit+v0n?CUvJ^(BYv!V zy)Q!{Pl;V=^cu2G^PhF7^Mx|V=VEU^ZQW-c>x17I=zYL09dSG-Kfw3AzDE88JL^v8 zyFKAo1FZwC=lbH${OUe4@4Q8By$cy&9)1YpD*kJ+o<{Kd@?4FPSApLfyaE23pV3PN zS}!>do`=6m$o=l_afbNUGkZ-}y1TZ1|5$HQ7+K&7>@$GBlK-*XR ziJtwa=Pe=p_V~45a$TwLSegk8`uMFJ{`;PF}!(R_SEPV5h z_rqQm{t5J4G=i>ypW^u0uV3Ufp+`ZN1&b2Td4RujC4p}}l^gr>`0sj(txC+vgGJH8+5Ci72P-q!}~jq7~Xy0ALy?F9aXo%gEee^;<|%tzmP(R&iUDaV}E_efPKDbqRdu;aQ1W8++%`-m|OVJ2xDLeKz7-|KEbY55D_; z1pK=A2~E8B&_~ek1+IhdTx1IRK)?~?x525{nLq5GqhOE&bdFMr^=t$SgVFJm1l)j~ z=h?Z6_r)0WPOy$Up!J-;U;ld!ko$X!zZ<6}{(JoGNACx)CF`)>@OK{1ZzJsAfQhlQ zKiP!e((ujG-Yf3YE$AI4-XUoB&%g6i`0Dq-?+5TG@};2nL@w3`JHUH+1o5oTtjkiO zKM(`!lA`E$AfDg#GPLvl0oXgop9eh?d;8ZQ=pM+QLRW)!pVwj?+2LD{THhu|UJkS_ zalg-mzY2T%hbzQO4W0(=hc}~t8qA8H7NGekIr53{y?>ogy07aY_qnY@hG6eKX+2yV z{l&f)=tgdvY2?ZHUL~NzqJI~3p4c6FIdRTmR|zZ#b^*Jh@A)nP zZJvLHp3m+3b|2+L|040GqMru-PS$HZy&Ac{=USgwuUpp^!{0>oozq>z-z((S^VWr) zL*K9O(filBEjfPdqgUZ?K6s7wXNPVH&VnBw`%$d#Bz*fs=Y`j>s{y6}eXeNOi$$TE zg8u!3eTe&O5%Gqyj)u^lqcp^c4*~(+&sVUs&JDo+YvP)RZ^GXI|2dAfLR%+w!(T1~ zA-67hhTQMs-y_*qEk@sdZ#T64(i>>&zy!p3fnWQ+Khb;1x`#kJN7;kja_sVe9nte% zH?Mz0ZXaa**%UwS`}4%}KI(|vI{6gzByct8eC;f8>Vxig=dlmaw{L66diCR6Xgqq> zB{A`5eI5fl>m}R*4(mz{^iW<9NlLP9{c~vNf8PHK;77)92(T;u5+L{e z#Y67j0knhm9<(kwi=F!>F8=J7CL;IyreK{nz;o!213!X;iRbTDbI|XL+S6D2=wfmtt;#|&7;fE%LjUIIOnzg+K2yp#5GUvz|MKsPUQdI@6h3}_dNPO zBEh#Fa?WdAPyo5_)%R}x=XC{CWE~S>bjEHpcnIvmdhVk)3AzmSAHWsJYat&2-?_Q{ zign;M{JVb^qwl_4kH6^X%|UMeIRk!O_|u7d9y&a<@7KB2TjW{6x%fEM=e;&$i{$SYxI-``HZp!Y#C?52RjK+jLBZzFo8;M>p4g73U=6!D$Iy3VNB zd5?L{ty^0Yw-0*Gd3=vch!X)l>lphQ&wUl-e$Q6WLs-`oX!E`H&SPGnmxgtiKdld4 zPXgrL=jow6Kb4_t5zqSV8UC+=-I33Mw(qWp{Tk4H;hfw4DK>VVo3;2$jGc4Ls_+N! z9v^|fV&4WFft~rvyl&rSU18s6zDke1=lM6{nP;=&S3lPM-=XK+KP`UL`;K*{1bq(A z-QW0)hTM9|-}ffq&vjS_`uosy;-|*mzR~(~A^r>C$NdxoJMZ}j$gAMT-<`X`w~s0X z?fu|AVjbe#pbviR*Fr=830fZ&U>!BV ztzA*%?)R$L`Fr^$=)T}qupIt;e>b25!3OwsoxfmL9<;BF&wChvzV**G^sGbuZg$ z_jl!%=-c;IMc>~Q!XWPmKNa{4|MsKl&>N0C9yp)x@(XC*_zV3T_;F5o7rys%32F4M zLT`ok{Ex-Ib>4jR=7B@Ojo4Xd#?`~wf~C>-}^BW@pE8T z1bGW+^Lk?Vap0TxqQf^oZh*F4t;zbEVb=+{>o19X2=d#|KcZ)U=)LGXXDxc(cjhI( zpL4{$_^FP6zq9?n=gB&uEP8|RpOW`uT|XV#-wUj>Jm0mji%;CA(AEVpq0YUBr5$2I)sgAR@S5qJjl zzVp17M&I8@ymvfLS+KKDjSp?SWyF09-?^gayeRzZ-~#lWd*3C#`^5gl_4+-w@Lv25 z{_fTbzV}u`{2hhA4mt#MGU7*IJ=R}ykQZ>B*m@oQ1yib~5CZ&@YYLx?(1A{{TCqZ=NylS~s0U z?=biaew$&}3EKMdJ$(1qJ!sF(RQx#4r~#b;KiQyOB2|8DGerbQ-8M`*TYN z&kBD(dgf2-ZR_$n$m^p21K1zD1UmoQ%sQRV-zKhpGC_O3CZcbj_A~xY!T$>EhMjdv zMfkf|-v{hk!~Yuo8SIuKe+KP)Y70F9x%=Gyq6vN);x{#VKSM_ZC!uGb^alR7_@9Y< zICLz~-!IME`{AcT-+i7JIwJB*(B{3w(5;B;Tyit~)!+@}_7#nJo{=q<4<3QP28;@SvHI8_B;G0bq2d1udVXeLKM=m}s|I!l@Z6ummg zonN$qE{H$R&k^j4!_R|#Zus`o)_J+m8-~A>;B@q7BX`a*40oD_!^=bqB zrC>c{u-gpX2y|XxUbc_)-c5~Na{PHt7hv!Gwgh?{&(jh7h`sZfTj&{i}FN?5V?~85Nb%tLKG_Ut$ zJul#&W!>4Ziv+zAJ@b5N^iv|wj{XbK{89|PyXaX59YxQ&(s^-D_&kbp>uidC1LV@HpyfwJ z&-&lrIij(S@yJs^`#u`Mx4sU9cHdqgj?Z-l`aXKW=OQEE8?P|5=fit;GwYa+KkKJE z(B8x6p_79f!M}Mwei!c_^WqHT_0Zo3K7?Nt+I?>QWnLe_`m9Iip>LmT9qn9a63=-N z|9^n4^BnYX^jhM74rpHW90#wjiev9Q#^2xSvo7Zb&h@>w8er$#=PLTCScmn7^Dg_~ zYuMFBf4l31KMQ;Re!@If0sUggS3)NPFT>A(zo(#mRCDC!OYceN3C_*0qW22F_IdWT z>5*IS+jj=S|ADygu(z(XKAM5vOXOqFw=aqU-+r$L@5SFC)?n}Rcz@z1z`XYZcFvLZ zVs{AtpD^(EmH5zwz+LDc2JQ2`=f=SQ4ZH|8#(o9y-S6F@+avd$D+}Mek_9>n{w}*N z*1ZE72>6x%u%6zIo%5np*qukO6!P8hy$@c%{}D6~TlbAX{}X86HJSCAciqqZ;k$n? z5Pt;xsnBEb=Uia}{BKxCA=iaoHR#3Yc`mHGn&E#aa_iE5*zE+(quy8c{q8UOPy52~ z#LocE!v6*0wTJ)M_@MReGU$QCp9#9|WYE^<&W~JAKjdMEV}D~kxq|h?K+pTG8hp>+ zP2~4jkNe+w(0KfJ0blu?*nNlH2I$Aglf%CQT5r3LIs{vvb%ph)`KJ?pJul<1Gw*s2 z1fus9@+Z*O!Fb>ho}(l?t`cb8@;on7jvx2^*U+z#FT?Lpa4r5%ATJ1g05nh7 zFaN8D-znf+?5!tnU~eC9|7rak4!bz$b%*vjL*TDJ_!9Jd{e`~w)HD28zuHe7hF^l` zbsn+=zI}xA4DZoZ#QPrHO}uv4<-=cuVE(-)_n>e8XCIyyzTaUE_T}Ig$L}9tZ1~oB z=I21xbq{%d&^Yg)?VGHx{CmOe_;+sTem>878e;D}cRlniKzb&le z8G5gw?bH3=J#Ita`_%r|x+o0t)%ZCS%%62)8uZ4p5YMgpjgec&Eya)T-+b!*@E7`1 z(Q5!b5864$QsO=1`34|2-`T&kgl~R1f&P5>)^j)Dx5aN<;=0egN8_PqU%MSY?i1&) z!?Cw7n}(hDxW5m@LN5$)tAo~q-v2&VIQ)di&ODzNyRTS}_q*pc8GOIT3;YZQ&HwN4 zV}0G6IAcKf(@f|-*t?&z!q<<#3;6!MH{3s)vA6GOh`n`+@2Mhu?@@n;a6Wv8bx-Fx z(}7pue}WDKTjK9Ew7+|-W<8s+bDmlk{cOmmfThs84%+{{M?WF{y+^F4th-yIcMN}N z@Us!SD)6(x_dO4U?u`94=!l?sS4=j?aw-)(7bfxp6;vULFzOe!38}{p@h!=fKZta3r`3ee1wv(DrW)q341Pz*xQ~{LF-Q9_DxN z2;cA9h4pO2?r;3rN2i2#Ua%f}>-U++eU3)>83o^aY9(>p@7C|;Yv)Pcrw6d_h+Rzl zb|$`kUv&81t9g*$#@_qNyyINuIeN}*y=NQ4KT6y@_|FY(zPJa!8T@J3Z6WSr=xgY? zzfPiWKKERE?pGl%qCe33q%8iOdmMsxKOH09M(jEwp9fm!+Q0OMZ@pySejL8_VH^DJ z#P20&?|tih@9hxCYlG&6(mbDe+`R0(`7r(~W9PlRll8<#UL1NodUfF^g>S#%?;z$y z`;^k?wPro`Yv-ZIBM$|7PI?k=19%SkX)p`+3qa?gkI;Ah*24$kNAl~d%nkyb9}$=bnFUXXFeH={$I#b;?MiUIhOYK zkRLP2Njx_+h}r_&ETb4>X_JA9#P8557k)2KMGz@7t#EqZ0o^ z7kl%Ta|-Kz-IadCr$+X`P^IKdv1H79|U@DFF-Fl_SRh`p*IsZ7PNEp zg78yv_J9BfX+n@qkk5A&zp6pzsowm z%ZZ+Km*49>eh#rt=OpHL?}KH?>!EMG5t4QJyVgG9?1w)AJI}X$yY+ziF#&eg%b~Dq z4d4FjIPv{XRggP}uZrCX^z1X1!Visp1>~NiX2|W=?0+|*@B5F6-9?_yddT|Md!ik3 z_v0mK^RRuG@oEsK8@LDzBHnxG6s*&}C=dLD`1gHIhJO)z|4!*=_!;nHK5$-Pon&3J zCs@3A#IY|n|8B&u_5KI++7Uk?@{agh4Bxq*b*lH0`NiKwTA;rT^zX&|{bUz@?;^LK z@V?5%dj3SO33}H+&#n9a0_%0YnIF9y;C1BQcmBSS8oA#q9Pv(qdsv5cg7ddP{8@hw zMz1?|<|&`ob75cG6Z?q7w;pu9=zjYXy|VZ(gnoW#-=lp;NaAlsZw>Z-->2|jV{bm+ zkDc?7o#@$D*r#{LZVhtlB=d9|{ER|ghxqR6y{;SiQskweqeFYYWyGKFu`BkeLHA>G z>>A)_C1{@c4f$j2W`N#b)_2xlCy_@2okQE#EW-apFfn@GGvirzdN4oe`?N0c{KsK^ zGlJ>+{eMLM40&4OdcT_YRwB2qiU;lQD1Wlf;n*+3ZzS}*->lD0BDY_2PL~6F^O*f$ zMf~(2PE%05P1w0U`)lhG^R4@9I`KU(J@7voel`46g7%yyM$fqZ9#)L?ya4ULtK!G! zb1qQ;zah~76?E>Eoj41z{}%oD;6*So@`%_M#;zRL59Gqye;s%q`E2ZC zK^FjfqgU3rU@_$H!JpAH@5F_l5xD{vtq|Z=17j``v$^7rivZabM4aeun>I&~KsbXS`=~;J+jC@9=BCkjM9j zd?vUOenb3chHpQ96~29KPvjloyC2QB8SrQSfm(pi@BG4i=$vOHik<`OMxWyq&(jIL z1gyik(NDz9jNE!_HhSm4!T5OvJrjBYwDVK@%K`XXie7T)-_iGb`#rrcUZeL8EQkJ9 z(7IpwdjkE5@ay6~2mFtqby{`oZu*^&n-6_o3Gg=`y$awC z`17zE1|0|5eLIzPJcU0JoCALc=sjSax&nXi(63Bf`xD=PAog95e}kR%>n-fvU+1A~ zWA_cTf7eh3eq#KT1s|bjy|$Qm)@RqClcQ(fYky(>>5Kp6==;7K6UVyF^V%DI>)k}q zPkc@=8Q2@YQ(4Da=sn0QVSgC;Sa3RemC@Fbx27Wkb*K-=Z)A;R-+_~Bg z?CgskK-<^GgWinY9sKOXZw=6XE*5&1(Tf0G63KBa{ROft$Q9qn->OQ|1JLg zT{|YUb>txIoL3cvwyyDg9>>n}Z@sV=etqnGKI>!qP3!mqtZO@VO+g+zzA~o$*k3_^E%HhDUk#lXJ@e8* z_}Ady!2U7(RPblPj|^r8Jzt)Se(1-+z9F>r^e60`OKd^UImsd9hp;oRoBz$L_FvXL zchSEOx^E*x`+grm&mitta2Ro|+l=pbtA$-j^n70-;5+v!jodoe``-KKXTHZU>~0YM z0Q{uTN%3+cZO$qBKahkhDppVz*l zEb9%4-24{<`XcE5kBvY3^2DtF1nBcV#cmvE9b=tlp4)(3QuNA0e+zm~G+>=S!4C`V zJi+hkKDmrM0qeD2y8z#P?)$d?9Eg4x)>9C=GjZEu{}K7uV0`tUy?=j0e-3=#+gbSL zaqA-M80%s0i6Qu31ttL#u`cUi_nGz0AoR?WFQA`;>0AeR1%LLl$>0|Try*|+y%|gk z`hJq}yx+ht1igj0mEre>_V;b;2m5vFo_@rgiXXqnZuqmn;>e3YCjpOuxsY2=n3wHi zqaZi`26&&CSL^Z|Te0&T*2dr>Xx-?&_AHqGO#I9PpMdtU<(J>PYOpKm0w-zTNxM`?ex+eV;!=t(0=}V*3kgHSI}dj zecyr5qgapo+5FQBzrK&W=$Xgt`~1DgdT<4P%ZS8(jlK1a?_nf*_G#WTRp9>sz6Z^} zjaXkx(7DNP#Iv6oirql;60(l8(AItKw6EEet|#lgAefCH!ty* z2|eqrZ1~-U{!!59k$(;QirD)dmZ5(S{xb9yK;MTB2Yr@!Kg0L_a9@9o-U8&pQ2rx4 zI2yg&(9X{)vmW!2^K@d$6wIp#7qAthuc3 zHFCe_PUxA$F{jC!GQc{uI{ZtZ^`XDl6vY7XzyJTbd$O+W#BYQ` zRnT*u8rnMAJd+Z?&gZJ*uN>$--v&Q6pf5EOA=>1(7{onC33)=f2GWN~E zu-HdOFC}z3=(EtiFMn6qi2ql}&)}~T=sj>AJ%9iCx1Yo=D|iq)=X%}`zQ4)nd9QB8 zZ(-~@gE_&K;Ck$yp+6Ulj(m&pdET60a`eoDVbJ$|)Z9GP&U(eXYMoFXJ@1bK*n6Id@thyc!B2MNrO}_EALO5)y@&Rop9=Xz@CW$r$63(U z+1B;Hz<-Y42G~!7UyyZL?>GnAjh^rSZ|EA>rGVayeQf-9fL{T=zqdN~+y_4a`VY}} zUhG`l`>Zv3{+(kS^sR>*BOird@1+r})45p{XzO0TzxB0!f_cyNIRA-`~|-)*l%<1^W6=>3XI9ClquM2jW(6cY|{+)o__ty#9yn2uL?oa0i z{vLJ;d3WN~AddZdawhmM>Zn2j~fboo#(TjwLi5VbiQ*FedoXPvGYFe3hjBw$^z8i3~imAf_1vj zx+=#{Eoj%3nE2lBzW1c?7o*=Fzs^T1LhnbP&qqS!=h3%s?Tp{W@b5xjg$~bp z_3wP_3GxA;b(?AyAW@w*L$cE1&d{u^|C0r=~Qzstx&fQi8c=y~41V|`(; zvz}Ulp80e?wDT+Pw+HZjf9de|8ouWKQoBqzAg`48b6*-`+|eWXXCFWSPFg}?5s1K z(>doKi98>e1RR7P>+-75&OZvGR~Y^_a1A&N+z9UDJ-L6q?|eVj!(-5E$ol;*?nCd1 zJIKvL*Rb>cdByrN6aNu-3OoB;`|1hEef}WmqWC=r?VKnqe#~!k(epf5XN|+oyn7hh zb9)Ux9kH(fR#6W$UQ=l6#535}#ZLd3cz)-f<`d_=%kXm=EP)^Ip|a3R@Us*DjX}Tf zPv|8F52Akoz3kBXwV#vjhr9{vur9Q2_4irp>-xmGhW>BFsRr%5Vj#5h9q)~s@a+>q z5Z`-nF8(wb6()@9n0`r z8T&$@_3}5or`T;-1BCapskVSjlJ_~`)BW=71)`-y?@KV_a5zvAN!D(#P_}{K)e>{ zc~1o3XCnNqpmljJ{J6hsp}z<{>%Y?YKM8+7wCnet*oplZ^gO55wax|I$JQy+@D~Mq zz`ERLzK?attAggU66o9Cna|AkewX;H&wADOVSMjt>%my~Jx9FIpy%NR^mzRC0J~u4 zd?P=4b>MqH?}J~D_vM_WFMeWyFTnKJuf@*3>38JM;8%dIhkjD%_{6CJe+c@OklWYx zg8vTv%wP`mtn2m!qa$GVJNg%qd!NTAuHSP#`u^_{a-%;MjDvmy(EaAV-c6ji;3fQ- zhs=jf;A{61IuU-YFLJ&Jt))NW6=B)EC__?7!fWzSL0E^(q=Zp(&U7ZiQ81}xO?ZlgjJV7u!^Jr=K z$Iy3Q6(N2yxZ~-0K24(5e zq+_QR?b?8qjcI_Myzj!z{t68-#0lxhEKYK_Rko!L#GB<7z)F4ZzCLLQeXx*Ym zgP_me>a}QGGBEoWhyS1dWX+W)civoi^XAHvBX6#J`EqB^lQ(GHtl4wq z%$qf5zC1bdeOcxA`F|Pyr9S%dZ^d%|d 0: + for job in tqdm(txts, desc=f"Phonemizing: {speaker_id}"): + outpath, text, language = job + phones = valle_phonemize(text) + data = { + "text": text.strip(), + "phonemes": phones, + "language": language, + } + open(_replace_file_extension(outpath, ".json"), 'w', encoding='utf-8').write(json.dumps(data)) + + if len(wavs) > 0: + for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): + try: + outpath, waveform, sample_rate = job + qnt = valle_quantize(waveform, sr=sample_rate, device=device) + qnt.save(_replace_file_extension(outpath, ".dac")) + except Exception as e: + print(f"Failed to quantize: {outpath}:", e) + continue + +open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) diff --git a/scripts/train_tokenizer.py b/scripts/train_tokenizer.py new file mode 100644 index 0000000..6b4058a --- /dev/null +++ b/scripts/train_tokenizer.py @@ -0,0 +1,57 @@ +import os +import json +import torch +import torchaudio + +from tqdm.auto import tqdm +from pathlib import Path + +from tokenizers import Tokenizer +from tokenizers.models import BPE, Unigram, WordLevel, WordPiece +from tokenizers.trainers import BpeTrainer +from tokenizers.pre_tokenizers import Whitespace +from tokenizers.processors import TemplateProcessing + +input_metadata = "training-24K" + +output_file = Path("./dataset.json") +tokenizer_data = [] + +def pad(num, zeroes): + return str(num).zfill(zeroes+1) + +if output_file.exists(): + tokenizer_data = json.loads(open(str(output_file), "r", encoding="utf-8").read()) +else: + for dataset_name in os.listdir(f'./{input_metadata}/'): + if not os.path.isdir(f'./{input_metadata}/{dataset_name}/'): + continue + + for speaker_id in tqdm(os.listdir(f'./{input_metadata}/{dataset_name}/'), desc="Processing speaker"): + if not os.path.isdir(f'./{input_metadata}/{dataset_name}/{speaker_id}'): + continue + + for id in os.listdir(f'./{input_metadata}/{dataset_name}/{speaker_id}/'): + if ".json" not in id: + continue + + metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/{id}') + metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) + + tokenizer_data.append( f'{"".join(metadata["phonemes"])}' ) + + open(output_file, 'w', encoding='utf-8').write(json.dumps(tokenizer_data)) + +unk_token = "" +spl_tokens = ["", "", unk_token, ""] + +trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 256) +tokenizer = Tokenizer(BPE(unk_token = unk_token)) +tokenizer.pre_tokenizer = Whitespace() +tokenizer.post_processor = TemplateProcessing( + single=" $A ", + special_tokens=[("", 1), ("", 2)], +) + +tokenizer.train_from_iterator(tokenizer_data, trainer=trainer) +tokenizer.save("./tokenizer.json") \ No newline at end of file diff --git a/vall_e/config.py b/vall_e/config.py index 1081e77..e5be02c 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -18,6 +18,9 @@ from omegaconf import OmegaConf from .utils.distributed import world_size +# Yuck +from transformers import PreTrainedTokenizerFast + @dataclass() class _Config: cfg_path: str | None = None @@ -540,10 +543,12 @@ class Config(_Config): inference: Inference = field(default_factory=lambda: Inference) bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) + tokenizer: str = "./tokenizer.json" + fp8: FP8 = field(default_factory=lambda: FP8) sample_rate: int = 24_000 - variable_sample_rate: bool = False + variable_sample_rate: bool = True @property def distributed(self): @@ -611,16 +616,19 @@ cfg = Config.from_cli() # OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves try: cfg.format() - - # cached_property stopped working... if cfg.dataset.use_hdf5: cfg.load_hdf5() - - except Exception as e: - print(e) + print("Error while parsing config YAML:", e) pass +try: + from transformers import PreTrainedTokenizerFast + cfg.tokenizer = (cfg.relpath if cfg.cfg_path is not None else Path("./data/")) / cfg.tokenizer + cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(cfg.tokenizer)) +except Exception as e: + print("Error while parsing tokenizer:", e) + pass if __name__ == "__main__": print(cfg) diff --git a/vall_e/data.py b/vall_e/data.py index fc49975..434b3d3 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -24,17 +24,17 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset as _Dataset from torch.utils.data.distributed import DistributedSampler from tqdm.auto import tqdm - # torch.multiprocessing.set_sharing_strategy("file_system") _logger = logging.getLogger(__name__) # to-do: clean up this symmap mess def get_phone_symmap(): - if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5: - return json.loads( cfg.hdf5['symmap'].asstr()[()] ) + return cfg.tokenizer.get_vocab() - return {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228} +def tokenize( phones ): + return tokenizer.encode( "".join(phones) ) + #return [*map(get_phone_symmap.get, _get_phones(path))] def get_lang_symmap(): return { @@ -178,7 +178,9 @@ def _get_phones(path, language="en"): else: content = open(_get_phone_path(path), "r", encoding="utf-8").read().split(" ") content = _cleanup_phones( content ) - return [""] + [ " " if not p else p for p in content ] + [""] + + return "".join(content) + #return [""] + [ " " if not p else p for p in content ] + [""] def _interleaved_reorder(l, fn): groups = defaultdict(list) @@ -435,7 +437,7 @@ class Dataset(_Dataset): text = torch.from_numpy(text).to(self.text_dtype) resps = torch.from_numpy(resps).to(torch.int16) else: - text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) + text = torch.tensor(tokenize( _get_phones( path ) )).to(self.text_dtype) resps = _load_quants(path) lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8) @@ -847,18 +849,21 @@ def create_dataset_hdf5( skip_existing=True ): # audio if audios: qnt = np.load(f'{root}/{name}/{id}{_get_quant_extension()}', allow_pickle=True)[()] - codes = torch.from_numpy(qnt["codes"].astype(int))[0].t() + codes = torch.from_numpy(qnt["codes"].astype(int))[0].t().to(dtype=torch.int16) if _get_quant_extension() == ".dac": if "audio" in group: del group["audio"] duration = qnt["metadata"]["original_length"] / qnt["metadata"]["sample_rate"] - metadata[id]["metadata"] = qnt["metadata"] + metadata[id]["metadata"] = { + "original_length": qnt["metadata"]["original_length"], + "sample_rate": qnt["metadata"]["sample_rate"], + } else: qnt = torch.load(f'{root}/{name}/{id}{_get_quant_extension()}')[0].t() duration = qnt.shape[0] / 75 - group.create_dataset('audio', data=qnt.numpy(), compression='lzf') + group.create_dataset('audio', data=qnt.numpy().astype(np.int16), compression='lzf') group.attrs['duration'] = duration metadata[id]["duration"] = duration @@ -869,17 +874,22 @@ def create_dataset_hdf5( skip_existing=True ): # text if texts: if _get_quant_extension() == ".json": - j_son = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) - content = j_son["phonemes"] + json_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) + content = json_metadata["phonemes"] else: content = open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read().split(" ") + """ phones = [f""] + [ " " if not p else p for p in content ] + [f""] for s in set(phones): if s not in symmap: symmap[s] = len(symmap.keys()) phn = [ symmap[s] for s in phones ] + """ + + phn = cfg.tokenizer.encode("".join(content)) + phn = np.array(phn).astype(np.uint8) if "text" in group: del group["text"] diff --git a/vall_e/inference.py b/vall_e/inference.py index e7c32b2..d14d38a 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -91,15 +91,8 @@ class TTS(): return text content = g2p.encode(text, language=language) - content = _cleanup_phones( content ) - # ick - try: - phones = [""] + [ " " if not p else p for p in content ] + [""] - return torch.tensor([*map(self.symmap.get, phones)]) - except Exception as e: - pass - phones = [ " " if not p else p for p in content ] - return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ]) + + return torch.tensor(cfg.tokenizer.encode( "".join( content ) )) def encode_lang( self, language ): symmap = get_lang_symmap()