From d7b63d2f709faeef5bf88fe77a4d3074cc16aa50 Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 18 Jun 2024 10:30:54 -0500
Subject: [PATCH] encoding mel tokens + dataset preparation

---
 data/mel_norms.pth                |  Bin 0 -> 1067 bytes
 data/tokenizer.json               |    1 +
 scripts/parse_ppp.py              |   96 -----------------------
 scripts/process_dataset.py        |   72 +++--------------
 scripts/process_libritts.py       |   56 ++------------
 setup.py                          |    3 +
 tortoise_tts/config.py            |    3 +-
 tortoise_tts/data.py              |  123 +++++++++++++++---------------
 tortoise_tts/emb/mel.py           |   45 +++++++----
 tortoise_tts/models/__init__.py   |   27 ++++++-
 tortoise_tts/models/arch_utils.py |   43 ++++++++++-
 tortoise_tts/models/dvae.py       |    2 +-
 tortoise_tts/train.py             |    5 +-
 tortoise_tts/version.py           |    2 +-
 14 files changed, 182 insertions(+), 296 deletions(-)
 create mode 100644 data/mel_norms.pth
 create mode 100644 data/tokenizer.json
 delete mode 100644 scripts/parse_ppp.py

diff --git a/data/mel_norms.pth b/data/mel_norms.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d8c73216492c6e3a58ea9ca74beb3c263dfc452e
GIT binary patch
literal 1067
zcmb7DZAepL6uxurqwB|}OZ_1RW#v~+vk>}G=G1}QqG=yk5o6t5Y&o2FyE8H>@e3;Y
zq7Nj&{E<;oXhn&n?RhOo%M3{oi69g#MWr$lB(i&%%C_)F=jG*a&V8PjbMAAFiIE3W
zlv+*wV_GVbvN+jewsJ1bY7s2@vJ$(|5KbNPv&tY=aEK_q*)7=YF52wmj=F7jt6AVm
zT@I(YJcH(`#Ka<%_`XX-T8cSd=B3a^yTc;v^(JXPs7NMa(s_B2T=Z<2n-6A80wlKlST?o!#hv)tr^ek
z#~7T7i$)*=(@8nW@DEe|r|VVuY=!OWQ|^uE0W!$Lj6wl|?7_6{l@
z-2*%D0M7YVYk#&^$f$uJCI=L#KW2wX#V9v#5^$3K~Qb&Mo`B~obP&t
z3sbLAwe1adi|+NegM-oi1i1Duwi--NBRa)-8F@Iet=E!W
zYy`md$D(9RoaRQ8ntpAWIKsj%?g_bH9Xe{0zB+iL+O^i%(%RAvaM_TEFp9OA6AltXM
za6jg_82Pf((gV4^RcZXVSHudJxQqY7#3+@U7C~#|smpL_nIv2C?cR`zTug{OX*a3f
T#aLQG6QfW{Ik|WwychcylPpr#

literal 0
HcmV?d00001

diff --git a/data/tokenizer.json b/data/tokenizer.json
new file mode 100644
index 0000000..a128f27
--- /dev/null
+++ b/data/tokenizer.json
@@ -0,0 +1 @@
+{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} \ No newline at end of file diff --git a/scripts/parse_ppp.py b/scripts/parse_ppp.py deleted file mode 100644 index 51f2300..0000000 --- a/scripts/parse_ppp.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import json -import torch - -from tqdm.auto import tqdm -from pathlib import Path -from vall_e.emb.g2p import encode as valle_phonemize -from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension - -device = "cuda" - -target = "in" - -audio_map = {} -text_map = {} - -data = {} - -for season in os.listdir(f"./{target}/"): - if not os.path.isdir(f"./{target}/{season}/"): - continue - - for episode in os.listdir(f"./{target}/{season}/"): - if not os.path.isdir(f"./{target}/{season}/{episode}/"): - continue - - for filename in os.listdir(f"./{target}/{season}/{episode}/"): - path = f'./{target}/{season}/{episode}/{filename}' - attrs = filename.split("_") - timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s' - - key = f'{episode}_{timestamp}' - - if filename[-5:] == ".flac": - name = attrs[3] - emotion = attrs[4] - quality = attrs[5] - - audio_map[key] = { - "path": path, - 'episode': episode, - "name": name, - "emotion": emotion, - "quality": quality, - "timestamp": timestamp, - } - - elif filename[-4:] == ".txt": - text_map[key] = open(path, encoding="utf-8").read() -txts = {} -wavs = [] - -for key, entry in audio_map.items(): - path = entry['path'] - name = entry['name'] - emotion = entry['emotion'] - quality = entry['quality'] - episode = entry['episode'] - path = entry['path'] - timestamp = entry['timestamp'] - transcription = text_map[key] - if name not in data: - data[name] = {} - os.makedirs(f'./training/{name}/', exist_ok=True) - os.makedirs(f'./voices/{name}/', exist_ok=True) - - key = f'{episode}_{timestamp}.flac' - os.rename(path, f'./voices/{name}/{key}') - - data[name][key] = { - "segments": [], - "language": "en", - "text": transcription, - "misc": { - "emotion": emotion, - "quality": quality, - "timestamp": timestamp, - "episode": episode, - } - } - - path = f'./voices/{name}/{key}' - txts[path] = transcription - wavs.append(Path(path)) - -for name in data.keys(): - open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) ) - -for key, text in tqdm(txts.items(), desc="Phonemizing..."): - path = Path(key) - phones = valle_phonemize(text) - open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones)) - -for path in tqdm(wavs, desc="Quantizing..."): - qnt = valle_quantize(path, device=device) - torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt")) \ No newline at end of file diff --git a/scripts/process_dataset.py 
b/scripts/process_dataset.py index d2ae337..e0ded3b 100644 --- a/scripts/process_dataset.py +++ b/scripts/process_dataset.py @@ -5,7 +5,7 @@ import torchaudio import numpy as np from tqdm.auto import tqdm from pathlib import Path -from vall_e.config import cfg +from tortoise_tts.config import cfg # things that could be args cfg.sample_rate = 24_000 @@ -16,15 +16,14 @@ cfg.inference.dtype = torch.bfloat16 cfg.inference.amp = True """ -from vall_e.emb.g2p import encode as valle_phonemize -from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension +from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension input_audio = "voices" input_metadata = "metadata" output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}" device = "cuda" -audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc" +audio_extension = ".mel" slice = "auto" missing = { @@ -57,30 +56,8 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): continue waveform, sample_rate = torchaudio.load(inpath) - qnt = valle_quantize(waveform, sr=sample_rate, device=device) - - if cfg.inference.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - }, - }) - else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": waveform.shape[-1], - "sample_rate": sample_rate, - }, - }) + mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device) + np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel ) continue @@ -177,45 +154,16 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')): )) if len(wavs) > 0: - for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): + for job in tqdm(wavs, desc=f"Encoding: {speaker_id}"): try: outpath, text, language, waveform, sample_rate = job phones = valle_phonemize(text) - qnt = valle_quantize(waveform, sr=sample_rate, device=device) - - if cfg.inference.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": waveform.shape[-1], - "sample_rate": sample_rate, - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) + mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device) + mel["text"] = text + np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel) except Exception as e: - print(f"Failed to quantize: {outpath}:", e) + print(f"Failed to encode: {outpath}:", e) continue open("./missing.json", 'w', 
encoding='utf-8').write(json.dumps(missing)) diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py index 53ce053..c4907df 100755 --- a/scripts/process_libritts.py +++ b/scripts/process_libritts.py @@ -18,9 +18,9 @@ cfg.inference.amp = True """ from vall_e.emb.g2p import encode as valle_phonemize -from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension +from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension -audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc" +audio_extension = ".mel" input_dataset = "LibriTTS_R" output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz" @@ -56,52 +56,10 @@ for dataset_name in os.listdir(f'./{input_dataset}/'): for paths in tqdm(txts, desc="Processing..."): inpath, outpath = paths try: - if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists(): - data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read()) - qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True) - - if not isinstance(data["phonemes"], str): - data["phonemes"] = "".join(data["phonemes"]) - - for k, v in data.items(): - qnt[()]['metadata'][k] = v - - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt) - else: - text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read() - - phones = valle_phonemize(text) - qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device) - - if cfg.inference.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - - "text": text.strip(), - "phonemes": "".join(phones), - "language": "en", - }, - }) - else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.shape[-1] / 75.0, - "sample_rate": cfg.sample_rate, - - "text": text.strip(), - "phonemes": "".join(phones), - "language": "en", - }, - }) + text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read() + + mel = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device) + mel["text"] = text + np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel) except Exception as e: tqdm.write(f"Failed to process: {paths}: {e}") diff --git a/setup.py b/setup.py index 84add5c..2350f5c 100755 --- a/setup.py +++ b/setup.py @@ -54,6 +54,9 @@ setup( "tokenizers", "transformers", + # + "rotary_embedding_torch", + # training bloat "auraloss[all]", # [all] is needed for MelSTFTLoss "h5py", diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py index 8a8efcd..8c3b937 100755 --- a/tortoise_tts/config.py +++ b/tortoise_tts/config.py @@ -502,7 +502,7 @@ class Optimizations: optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation) bitsandbytes: bool = False # use bitsandbytes - dadaptation: bool = True # use dadaptation optimizer + dadaptation: bool = False # use dadaptation optimizer bitnet: bool = False # use bitnet fp8: bool = False # use fp8 @@ -525,6 +525,7 @@ class 
Config(BaseConfig): tokenizer: str = "./tokenizer.json" sample_rate: int = 24_000 + audio_backend: str = "mel" @property def model(self): diff --git a/tortoise_tts/data.py b/tortoise_tts/data.py index 718294a..3a911a3 100755 --- a/tortoise_tts/data.py +++ b/tortoise_tts/data.py @@ -169,8 +169,10 @@ def _get_paths_of_extensions( path, extensions=_get_mel_extension(), validate=Fa def _load_mels(path, return_metadata=False) -> Tensor: mel = np.load(_get_mel_path(path), allow_pickle=True)[()] if return_metadata: - return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16), mel["metadata"] - return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16) + mel["metadata"]["text"] = mel["text"] + + return mel["codes"].to(torch.int16), mel["metadata"] + return mel["codes"].to(torch.int16) # prune consecutive spaces def _cleanup_phones( phones, targets=[" "]): @@ -453,37 +455,19 @@ class Dataset(_Dataset): ) """ - prom_length = 0 - trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) + path = random.choice(choices) + if cfg.dataset.use_hdf5: + key = _get_hdf5_path(path) - for _ in range(cfg.dataset.max_prompts): - path = random.choice(choices) - if cfg.dataset.use_hdf5: - key = _get_hdf5_path(path) + if "audio" not in cfg.hdf5[key]: + _logger.warning(f'MISSING AUDIO: {key}') + return - if "audio" not in cfg.hdf5[key]: - _logger.warning(f'MISSING AUDIO: {key}') - continue - - mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16) - else: - mel = _load_mels(path, return_metadata=False) - - if 0 < trim_length and trim_length < mel.shape[0]: - mel = trim( mel, trim_length ) - - prom_list.append(mel) - prom_length += mel.shape[0] - - if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance: - break - - # might be better to decode => concat waveforms with silence in between => reencode - # as you technically can't just append encodec sequences together like this without issues - prom = torch.cat(prom_list) - - if 0 < trim_length and trim_length < prom.shape[0]: - prom = trim( prom, trim_length ) + # audio / cond / latents + # parameter names and documentation are weird + prom = torch.from_numpy(cfg.hdf5[key]["cond"]).to(torch.int16) + else: + prom = _load_mels(path, return_metadata=False) return prom @@ -507,15 +491,36 @@ class Dataset(_Dataset): spkr_group = self.get_speaker_group(path) #spkr_group_id = self.spkr_group_symmap[spkr_group] + if cfg.dataset.use_hdf5: + key = _get_hdf5_path(path) + + if key not in cfg.hdf5: + raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}') + + text = cfg.hdf5[key]["text"][:] + mel = cfg.hdf5[key]["audio"][:] + latents = cfg.hdf5[key]["latents"][:] + + text = torch.from_numpy(text).to(self.text_dtype) + mel = torch.from_numpy(mel).to(torch.int16) + latents = torch.from_numpy(latents) + wav_length = cfg.hdf5[key].attrs["wav_length"] + else: + mel, metadata = _load_mels(path, return_metadata=True) + text = torch.tensor(metadata["text"]).to(self.text_dtype) + latents = torch.from_numpy(metadata["latent"][0]) + wav_length = metadata["wav_length"] return dict( index=index, path=Path(path), spkr_name=spkr_name, spkr_id=spkr_id, - #text=text, - #proms=proms, - #resps=resps, + + latents=latents, + text=text, + mel=mel, + wav_length=wav_length, ) def head_(self, n): @@ -603,12 +608,14 @@ def create_train_val_dataloader(): return train_dl, subtrain_dl, val_dl def unpack_audio( npz ): - mel = 
torch.from_numpy(npz["codes"].astype(int))[0].t().to(dtype=torch.int16) + mel = npz["codes"].to(dtype=torch.int16, device="cpu") + conds = npz["conds"][0].to(dtype=torch.int16, device="cpu") + latent = npz["latent"][0].to(dtype=torch.int16, device="cpu") metadata = {} - if "text" in npz["metadata"]: - metadata["text"] = npz["metadata"]["text"] + if "text" in npz: + metadata["text"] = npz["text"] if "phonemes" in npz["metadata"]: metadata["phonemes"] = npz["metadata"]["phonemes"] @@ -616,10 +623,15 @@ def unpack_audio( npz ): if "language" in npz["metadata"]: metadata["language"] = npz["metadata"]["language"] - if "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]: + if "original_length" in npz["metadata"]: + metadata["wav_length"] = npz["metadata"]["original_length"] + + if "duration" in npz["metadata"]: + metadata["duration"] = npz["metadata"]["duration"] + elif "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]: metadata["duration"] = npz["metadata"]["original_length"] / npz["metadata"]["sample_rate"] - return mel, metadata + return mel, conds, latent, metadata # parse dataset into better to sample metadata def create_dataset_metadata( skip_existing=True ): @@ -672,7 +684,7 @@ def create_dataset_metadata( skip_existing=True ): if audios: # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()] - mel, utterance_metadata = unpack_audio( npz ) + mel, conds, latents, utterance_metadata = unpack_audio( npz ) # text if texts and text_exists and not utterance_metadata: utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read()) @@ -755,10 +767,16 @@ def create_dataset_hdf5( skip_existing=True ): # audio if audios: npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()] - mel, utterance_metadata = unpack_audio( npz ) + mel, conds, latents, utterance_metadata = unpack_audio( npz ) if "audio" not in group: group.create_dataset('audio', data=mel.numpy().astype(np.int16), compression='lzf') + + if "conds" not in group: + group.create_dataset('conds', data=conds.numpy().astype(np.int16), compression='lzf') + + if "latents" not in group: + group.create_dataset('latents', data=latents.numpy().astype(np.int16), compression='lzf') # text if texts: @@ -778,6 +796,7 @@ def create_dataset_hdf5( skip_existing=True ): except Exception as e: tqdm.write(f'Error while processing {id}: {e}') + raise e with open(str(metadata_path), "w", encoding="utf-8") as f: f.write( json.dumps( metadata ) ) @@ -837,27 +856,9 @@ if __name__ == "__main__": samples = { "training": [ next(iter(train_dl)), next(iter(train_dl)) ], - "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ], - "validation": [ next(iter(val_dl)), next(iter(val_dl)) ], + #"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ], + #"validation": [ next(iter(val_dl)), next(iter(val_dl)) ], } - - Path("./data/sample-test/").mkdir(parents=True, exist_ok=True) - - for k, v in samples.items(): - for i in range(len(v)): - for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."): - """ - try: - decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" ) - except Exception as e: - print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e)) - try: - decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" ) - except 
Exception as e: - print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e)) - """ - v[i]['proms'][j] = v[i]['proms'][j].shape - v[i]['resps'][j] = v[i]['resps'][j].shape for k, v in samples.items(): for i in range(len(v)): diff --git a/tortoise_tts/emb/mel.py b/tortoise_tts/emb/mel.py index 2fcb948..6af96a1 100755 --- a/tortoise_tts/emb/mel.py +++ b/tortoise_tts/emb/mel.py @@ -15,6 +15,19 @@ from tqdm import tqdm from ..models import load_model, unload_model +import torch.nn.functional as F + +def pad_or_truncate(t, length): + """ + Utility function for forcing to have the specified sequence length, whether by clipping it or padding it with 0s. + """ + if t.shape[-1] == length: + return t + elif t.shape[-1] < length: + return F.pad(t, (0, length-t.shape[-1])) + else: + return t[..., :length] + # decodes mel spectrogram into a wav @torch.inference_mode() def decode(codes: Tensor, device="cuda"): @@ -34,30 +47,30 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"): def _replace_file_extension(path, suffix): return (path.parent / path.name.split(".")[0]).with_suffix(suffix) -def format_autoregressive_conditioning( wav, cond_length=132300, device ): +def format_autoregressive_conditioning( wav, cond_length=132300, device="cuda" ): """ Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models. """ model = load_model("tms", device=device) - gap = wav.shape[-1] - cond_length - - if gap < 0: - wav = F.pad(wav, pad=(0, abs(gap))) - elif gap > 0: - rand_start = random.randint(0, gap) - wav = wav[:, rand_start:rand_start + cond_length] + if cond_length > 0: + gap = wav.shape[-1] - cond_length + if gap < 0: + wav = F.pad(wav, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + wav = wav[:, rand_start:rand_start + cond_length] mel_clip = model(wav.unsqueeze(0)).squeeze(0) # ??? return mel_clip.unsqueeze(0).to(device) # ??? 
def format_diffusion_conditioning( sample, device, do_normalization=False ): - model = load_model("stft", device=device) + model = load_model("stft", device=device, sr=24_000) sample = torchaudio.functional.resample(sample, 22050, 24000) sample = pad_or_truncate(sample, 102400) sample = sample.to(device) - mel = model.mel_spectrogram(wav) + mel = model.mel_spectrogram(sample) """ if do_normalization: mel = normalize_tacotron_mel(mel) @@ -70,6 +83,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"): dvae = load_model("dvae", device=device) unified_voice = load_model("unified_voice", device=device) diffusion = load_model("diffusion", device=device) + mel_inputs = format_autoregressive_conditioning( wav, 0, device ) wav_length = wav.shape[-1] duration = wav_length / sr @@ -78,16 +92,19 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"): diffusion_conds = torch.stack([ format_diffusion_conditioning(wav.to(device), device=device) ], dim=1) codes = dvae.get_codebook_indices( mel_inputs ) - auto_latent = unified_voice.get_conditioning(autoregressive_conds) + + autoregressive_latent = unified_voice.get_conditioning(autoregressive_conds) diffusion_latent = diffusion.get_conditioning(diffusion_conds) return { "codes": codes, "conds": (autoregressive_conds, diffusion_conds), "latent": (autoregressive_latent, diffusion_latent), - "original_length": wav_length, - "sample_rate": sr, - "duration": duration + "metadata": { + "original_length": wav_length, + "sample_rate": sr, + "duration": duration + } } def encode_from_files(paths, device="cuda"): diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py index c3034b1..711b31c 100755 --- a/tortoise_tts/models/__init__.py +++ b/tortoise_tts/models/__init__.py @@ -3,7 +3,7 @@ from functools import cache -from ..arch_utils import TorchMelSpectrogram, TacotronSTFT +from .arch_utils import TorchMelSpectrogram, TacotronSTFT from .unified_voice import UnifiedVoice from .diffusion import DiffusionTTS @@ -11,26 +11,45 @@ from .vocoder import UnivNetGenerator from .clvp import CLVP from .dvae import DiscreteVAE +import os +import torch + +DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/') + # semi-necessary as a way to provide a mechanism for other portions of the program to access models @cache def load_model(name, device="cuda", **kwargs): + load_path = None if "autoregressive" in name or "unified_voice" in name: model = UnifiedVoice(**kwargs) + load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth' elif "diffusion" in name: model = DiffusionTTS(**kwargs) + load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth' elif "clvp" in name: model = CLVP(**kwargs) + load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth' elif "vocoder" in name: model = UnivNetGenerator(**kwargs) + load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth' elif "dvae" in name: + load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth' model = DiscreteVAE(**kwargs) - # to-do: figure out of the below two give the exact same output, since the AR uses #1, the Diffusion uses #2 + # to-do: figure out of the below two give the exact same output elif "stft" in name: - model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs) + sr = kwargs.pop("sr") + if sr == 24_000: + model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs) + else: + model = TacotronSTFT(**kwargs) elif "tms" in name: model = TorchMelSpectrogram(**kwargs) - + model = model.to(device=device) + + if load_path is not None: + 
model.load_state_dict(torch.load(load_path, map_location=device), strict=False) + return model def unload_model(): diff --git a/tortoise_tts/models/arch_utils.py b/tortoise_tts/models/arch_utils.py index 473a217..832aadb 100644 --- a/tortoise_tts/models/arch_utils.py +++ b/tortoise_tts/models/arch_utils.py @@ -289,7 +289,7 @@ class AudioMiniEncoder(nn.Module): return h[:, :, 0] -DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth') +DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/mel_norms.pth') class TorchMelSpectrogram(nn.Module): @@ -463,6 +463,34 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800, x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] return x +TACOTRON_MEL_MAX = 2.3143386840820312 +TACOTRON_MEL_MIN = -11.512925148010254 + + +def denormalize_tacotron_mel(norm_mel): + return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN + + +def normalize_tacotron_mel(mel): + return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C class STFT(torch.nn.Module): """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" @@ -566,9 +594,16 @@ class STFT(torch.nn.Module): return reconstruction class TacotronSTFT(torch.nn.Module): - def __init__(self, filter_length=1024, hop_length=256, win_length=1024, - n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, - mel_fmax=8000.0): + def __init__( + self, + filter_length=1024, + hop_length=256, + win_length=1024, + n_mel_channels=80, + sampling_rate=22050, + mel_fmin=0.0, + mel_fmax=8000.0 + ): super().__init__() self.n_mel_channels = n_mel_channels self.sampling_rate = sampling_rate diff --git a/tortoise_tts/models/dvae.py b/tortoise_tts/models/dvae.py index 71e6098..f1a15d6 100644 --- a/tortoise_tts/models/dvae.py +++ b/tortoise_tts/models/dvae.py @@ -119,7 +119,7 @@ class DiscreteVAE(nn.Module): positional_dims = 1, # 2 num_tokens = 8192, # 512 codebook_dim = 512, - num_layers = 3, + num_layers = 2, # 3 num_resnet_blocks = 3, # 0 hidden_dim = 512, # 64 channels = 80, # 3 diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py index 91ce94c..62bc5c6 100755 --- a/tortoise_tts/train.py +++ b/tortoise_tts/train.py @@ -2,10 +2,9 @@ from .config import cfg from .data import create_train_val_dataloader -from .emb import qnt +from .emb import mel from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc -from .data import fold_inputs, unfold_outputs from .utils.distributed import is_global_leader import auraloss @@ -34,7 +33,7 @@ def train_feeder(engine, batch): text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True) text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True) mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True) - wav_lengths = pad_sequence([ length for length in batch["wav_lengths"] ], batch_first = True) + wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True) engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths) diff --git a/tortoise_tts/version.py 
b/tortoise_tts/version.py
index 24ea8a3..6e7f03e 100644
--- a/tortoise_tts/version.py
+++ b/tortoise_tts/version.py
@@ -1 +1 @@
-__version__ = "0.0.1-dev20240602082927"
+__version__ = "0.0.1-dev20240617224834"
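A few usage sketches against the code this patch introduces follow; none of them are part of the diff itself. First, the new data/tokenizer.json is a plain HuggingFace tokenizers BPE file (lowercase characters plus common merges, with [STOP], [UNK] and [SPACE] specials and a Whitespace pre-tokenizer). A minimal sketch of loading it with the tokenizers package already listed in setup.py; callers are expected to lowercase text first, and how spaces get mapped to [SPACE] is left to whatever wrapper consumes this file:

# Illustration only: load and exercise the BPE vocab shipped as data/tokenizer.json.
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("data/tokenizer.json")

encoded = tokenizer.encode("anything can be said here".lower())
print(encoded.tokens)             # subword pieces, e.g. "an", "y", "thing", ...
print(encoded.ids)                # integer ids fed to the text side of the model
print(tokenizer.get_vocab_size()) # small vocab, a few hundred entries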
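scripts/process_dataset.py and scripts/process_libritts.py now persist whatever tortoise_tts.emb.mel.encode returns straight into a .mel file with np.save, and data.py reads it back with np.load(..., allow_pickle=True)[()]. A minimal round-trip sketch under the layout added in this patch (codes / conds / latent / metadata); the file paths and transcript below are made up:

# Illustration only: write one ".mel" payload the way the dataset scripts do,
# then read it back the way data.py does.
import numpy as np
import torch
import torchaudio
from pathlib import Path

from tortoise_tts.emb.mel import encode, _replace_file_extension

wav_path = Path("./voices/speaker/utterance.wav")      # hypothetical input
waveform, sample_rate = torchaudio.load(wav_path)

payload = encode(waveform, sr=sample_rate, device="cuda")
payload["text"] = "an example transcription"           # process_dataset.py attaches the transcript at the top level

outpath = _replace_file_extension(wav_path, ".mel")
np.save(open(outpath, "wb"), payload)                  # np.save pickles the dict as a 0-d object array

restored = np.load(outpath, allow_pickle=True)[()]     # hence allow_pickle=True and the [()] unwrap
codes = restored["codes"].to(torch.int16)              # discrete dVAE mel tokens
autoregressive_latent, diffusion_latent = restored["latent"]
print(codes.shape, restored["metadata"]["duration"])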
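When cfg.dataset.use_hdf5 is set, create_dataset_hdf5 now stores four datasets per utterance (audio for the dVAE mel codes, conds, latents and text) plus attributes such as wav_length, which is what Dataset.__getitem__ reads back above. A short h5py sketch of inspecting one entry; the file name and group key are placeholders, since the real ones come from the config:

# Illustration only: peek at one utterance group in the training HDF5.
import h5py

with h5py.File("./training/dataset.h5", "r") as hf:                  # hypothetical path, cfg decides the real one
    group = hf["LibriTTS-Train/1034/1034_121119_000001_000001"]      # hypothetical key

    mel_codes  = group["audio"][:]    # int16 dVAE mel token ids
    conds      = group["conds"][:]    # conditioning mel clips
    latents    = group["latents"][:]  # conditioning latents
    text       = group["text"][:]     # tokenized text ids
    wav_length = group.attrs["wav_length"]

    print(mel_codes.shape, latents.shape, text.shape, wav_length)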
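load_model in tortoise_tts/models/__init__.py now also resolves pretrained weights from the repo's data/ directory and is functools.cache'd, so repeated lookups by name share one module. A small sketch, assuming the corresponding .pth files have already been placed in data/:

# Illustration only: load_model caches per argument set and pulls weights
# from DEFAULT_MODEL_PATH (data/) with strict=False.
from tortoise_tts.models import load_model

dvae = load_model("dvae", device="cuda")             # loads data/dvae.pth on first use
assert load_model("dvae", device="cuda") is dvae     # functools.cache hands back the same module

stft = load_model("stft", device="cuda", sr=24_000)  # the STFT helper now expects an explicit sr kwarg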
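arch_utils.py also gains mel normalization helpers: they affinely map log-mel values from [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] onto [-1, 1] and back, so the pair are exact inverses. A quick check:

# Illustration only: normalize/denormalize are inverse affine maps.
import torch
from tortoise_tts.models.arch_utils import (
    TACOTRON_MEL_MIN, TACOTRON_MEL_MAX,
    normalize_tacotron_mel, denormalize_tacotron_mel,
)

mel = torch.linspace(TACOTRON_MEL_MIN, TACOTRON_MEL_MAX, steps=5)
norm = normalize_tacotron_mel(mel)                   # tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
assert torch.allclose(denormalize_tacotron_mel(norm), mel)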