encoding mel tokens + dataset preparation

This commit is contained in:
mrq 2024-06-18 10:30:54 -05:00
parent 37ec9f1b79
commit d7b63d2f70
14 changed files with 182 additions and 296 deletions

BIN
data/mel_norms.pth Normal file

Binary file not shown.

1
data/tokenizer.json Normal file
View File

@ -0,0 +1 @@
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}

View File

@ -1,96 +0,0 @@
import os
import json
import torch

from tqdm.auto import tqdm
from pathlib import Path

from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension

device = "cuda"
target = "in"

audio_map = {}
text_map = {}
data = {}

for season in os.listdir(f"./{target}/"):
    if not os.path.isdir(f"./{target}/{season}/"):
        continue
    for episode in os.listdir(f"./{target}/{season}/"):
        if not os.path.isdir(f"./{target}/{season}/{episode}/"):
            continue
        for filename in os.listdir(f"./{target}/{season}/{episode}/"):
            path = f'./{target}/{season}/{episode}/{filename}'
            attrs = filename.split("_")
            timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
            key = f'{episode}_{timestamp}'

            if filename[-5:] == ".flac":
                name = attrs[3]
                emotion = attrs[4]
                quality = attrs[5]

                audio_map[key] = {
                    "path": path,
                    "episode": episode,
                    "name": name,
                    "emotion": emotion,
                    "quality": quality,
                    "timestamp": timestamp,
                }
            elif filename[-4:] == ".txt":
                text_map[key] = open(path, encoding="utf-8").read()

txts = {}
wavs = []

for key, entry in audio_map.items():
    path = entry['path']
    name = entry['name']
    emotion = entry['emotion']
    quality = entry['quality']
    episode = entry['episode']
    timestamp = entry['timestamp']
    transcription = text_map[key]

    if name not in data:
        data[name] = {}
        os.makedirs(f'./training/{name}/', exist_ok=True)
        os.makedirs(f'./voices/{name}/', exist_ok=True)

    key = f'{episode}_{timestamp}.flac'
    os.rename(path, f'./voices/{name}/{key}')

    data[name][key] = {
        "segments": [],
        "language": "en",
        "text": transcription,
        "misc": {
            "emotion": emotion,
            "quality": quality,
            "timestamp": timestamp,
            "episode": episode,
        }
    }

    path = f'./voices/{name}/{key}'
    txts[path] = transcription
    wavs.append(Path(path))

for name in data.keys():
    open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )

for key, text in tqdm(txts.items(), desc="Phonemizing..."):
    path = Path(key)
    phones = valle_phonemize(text)
    open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))

for path in tqdm(wavs, desc="Quantizing..."):
    qnt = valle_quantize(path, device=device)
    torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
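
The deleted script above assumes source filenames encode a timestamp plus speaker, emotion, and quality fields separated by underscores. A hedged illustration of what it expects (the filename itself is hypothetical):

# hypothetical filename; fields are {HH}_{MM}_{SS}_{speaker}_{emotion}_{quality}
filename = "01_23_45_narrator_neutral_clean.flac"
attrs = filename.split("_")
timestamp = f"{attrs[0]}h{attrs[1]}m{attrs[2]}s"       # "01h23m45s"
name, emotion, quality = attrs[3], attrs[4], attrs[5]  # "narrator", "neutral", "clean.flac"
# note: with exactly six fields, `quality` keeps the extension unless more
# underscore-separated fields follow it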

View File

@ -5,7 +5,7 @@ import torchaudio
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
-from vall_e.config import cfg
+from tortoise_tts.config import cfg

# things that could be args
cfg.sample_rate = 24_000
@ -16,15 +16,14 @@ cfg.inference.dtype = torch.bfloat16
cfg.inference.amp = True
"""

-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
+from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension

input_audio = "voices"
input_metadata = "metadata"
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
device = "cuda"

-audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+audio_extension = ".mel"

slice = "auto"
missing = {
@ -57,30 +56,8 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
            continue

        waveform, sample_rate = torchaudio.load(inpath)
-        qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-
-        if cfg.inference.audio_backend == "dac":
-            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                "metadata": {
-                    "original_length": qnt.original_length,
-                    "sample_rate": qnt.sample_rate,
-                    "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                    "chunk_length": qnt.chunk_length,
-                    "channels": qnt.channels,
-                    "padding": qnt.padding,
-                    "dac_version": "1.0.0",
-                },
-            })
-        else:
-            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                "codes": qnt.cpu().numpy().astype(np.uint16),
-                "metadata": {
-                    "original_length": waveform.shape[-1],
-                    "sample_rate": sample_rate,
-                },
-            })
+        mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
+        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel )

        continue
@ -177,45 +154,16 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
    ))

    if len(wavs) > 0:
-        for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+        for job in tqdm(wavs, desc=f"Encoding: {speaker_id}"):
            try:
                outpath, text, language, waveform, sample_rate = job
                phones = valle_phonemize(text)
-                qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-
-                if cfg.inference.audio_backend == "dac":
-                    np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                        "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                        "metadata": {
-                            "original_length": qnt.original_length,
-                            "sample_rate": qnt.sample_rate,
-                            "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                            "chunk_length": qnt.chunk_length,
-                            "channels": qnt.channels,
-                            "padding": qnt.padding,
-                            "dac_version": "1.0.0",
-                            "text": text.strip(),
-                            "phonemes": "".join(phones),
-                            "language": language,
-                        },
-                    })
-                else:
-                    np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                        "codes": qnt.cpu().numpy().astype(np.uint16),
-                        "metadata": {
-                            "original_length": waveform.shape[-1],
-                            "sample_rate": sample_rate,
-                            "text": text.strip(),
-                            "phonemes": "".join(phones),
-                            "language": language,
-                        },
-                    })
+                mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
+                mel["text"] = text
+                np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
            except Exception as e:
-                print(f"Failed to quantize: {outpath}:", e)
+                print(f"Failed to encode: {outpath}:", e)
                continue

open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))

View File

@ -18,9 +18,9 @@ cfg.inference.amp = True
""" """
from vall_e.emb.g2p import encode as valle_phonemize from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc" audio_extension = ".mel"
input_dataset = "LibriTTS_R" input_dataset = "LibriTTS_R"
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz" output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
@ -56,52 +56,10 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
for paths in tqdm(txts, desc="Processing..."):
    inpath, outpath = paths
    try:
-        if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
-            data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
-            qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
-            if not isinstance(data["phonemes"], str):
-                data["phonemes"] = "".join(data["phonemes"])
-            for k, v in data.items():
-                qnt[()]['metadata'][k] = v
-            np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
-        else:
-            text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
-            phones = valle_phonemize(text)
-            qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
-            if cfg.inference.audio_backend == "dac":
-                np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                    "codes": qnt.codes.cpu().numpy().astype(np.uint16),
-                    "metadata": {
-                        "original_length": qnt.original_length,
-                        "sample_rate": qnt.sample_rate,
-                        "input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-                        "chunk_length": qnt.chunk_length,
-                        "channels": qnt.channels,
-                        "padding": qnt.padding,
-                        "dac_version": "1.0.0",
-                        "text": text.strip(),
-                        "phonemes": "".join(phones),
-                        "language": "en",
-                    },
-                })
-            else:
-                np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-                    "codes": qnt.cpu().numpy().astype(np.uint16),
-                    "metadata": {
-                        "original_length": qnt.shape[-1] / 75.0,
-                        "sample_rate": cfg.sample_rate,
-                        "text": text.strip(),
-                        "phonemes": "".join(phones),
-                        "language": "en",
-                    },
-                })
+        text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+        mel = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
+        mel["text"] = text
+        np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
    except Exception as e:
        tqdm.write(f"Failed to process: {paths}: {e}")

View File

@ -54,6 +54,9 @@ setup(
"tokenizers", "tokenizers",
"transformers", "transformers",
#
"rotary_embedding_torch",
# training bloat # training bloat
"auraloss[all]", # [all] is needed for MelSTFTLoss "auraloss[all]", # [all] is needed for MelSTFTLoss
"h5py", "h5py",

View File

@ -502,7 +502,7 @@ class Optimizations:
    optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
    bitsandbytes: bool = False # use bitsandbytes
-    dadaptation: bool = True # use dadaptation optimizer
+    dadaptation: bool = False # use dadaptation optimizer
    bitnet: bool = False # use bitnet
    fp8: bool = False # use fp8
@ -525,6 +525,7 @@ class Config(BaseConfig):
    tokenizer: str = "./tokenizer.json"
    sample_rate: int = 24_000
+    audio_backend: str = "mel"

    @property
    def model(self):

View File

@ -169,8 +169,10 @@ def _get_paths_of_extensions( path, extensions=_get_mel_extension(), validate=Fa
def _load_mels(path, return_metadata=False) -> Tensor:
    mel = np.load(_get_mel_path(path), allow_pickle=True)[()]
    if return_metadata:
-        return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16), mel["metadata"]
-    return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16)
+        mel["metadata"]["text"] = mel["text"]
+        return mel["codes"].to(torch.int16), mel["metadata"]
+    return mel["codes"].to(torch.int16)

# prune consecutive spaces
def _cleanup_phones( phones, targets=[" "]):
@ -453,37 +455,19 @@ class Dataset(_Dataset):
    )
    """

-    prom_length = 0
-    trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
-
-    for _ in range(cfg.dataset.max_prompts):
-        path = random.choice(choices)
-        if cfg.dataset.use_hdf5:
-            key = _get_hdf5_path(path)
-            if "audio" not in cfg.hdf5[key]:
-                _logger.warning(f'MISSING AUDIO: {key}')
-                continue
-            mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
-        else:
-            mel = _load_mels(path, return_metadata=False)
-
-        if 0 < trim_length and trim_length < mel.shape[0]:
-            mel = trim( mel, trim_length )
-
-        prom_list.append(mel)
-        prom_length += mel.shape[0]
-
-        if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
-            break
-
-    # might be better to decode => concat waveforms with silence in between => reencode
-    # as you technically can't just append encodec sequences together like this without issues
-    prom = torch.cat(prom_list)
-
-    if 0 < trim_length and trim_length < prom.shape[0]:
-        prom = trim( prom, trim_length )
+    path = random.choice(choices)
+    if cfg.dataset.use_hdf5:
+        key = _get_hdf5_path(path)
+        if "audio" not in cfg.hdf5[key]:
+            _logger.warning(f'MISSING AUDIO: {key}')
+            return
+
+        # audio / cond / latents
+        # parameter names and documentation are weird
+        prom = torch.from_numpy(cfg.hdf5[key]["cond"]).to(torch.int16)
+    else:
+        prom = _load_mels(path, return_metadata=False)

    return prom
@ -507,15 +491,36 @@ class Dataset(_Dataset):
    spkr_group = self.get_speaker_group(path)
    #spkr_group_id = self.spkr_group_symmap[spkr_group]

+    if cfg.dataset.use_hdf5:
+        key = _get_hdf5_path(path)
+        if key not in cfg.hdf5:
+            raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
+
+        text = cfg.hdf5[key]["text"][:]
+        mel = cfg.hdf5[key]["audio"][:]
+        latents = cfg.hdf5[key]["latents"][:]
+
+        text = torch.from_numpy(text).to(self.text_dtype)
+        mel = torch.from_numpy(mel).to(torch.int16)
+        latents = torch.from_numpy(latents)
+
+        wav_length = cfg.hdf5[key].attrs["wav_length"]
+    else:
+        mel, metadata = _load_mels(path, return_metadata=True)
+        text = torch.tensor(metadata["text"]).to(self.text_dtype)
+        latents = torch.from_numpy(metadata["latent"][0])
+        wav_length = metadata["wav_length"]

    return dict(
        index=index,
        path=Path(path),
        spkr_name=spkr_name,
        spkr_id=spkr_id,
-        #text=text,
-        #proms=proms,
-        #resps=resps,
+        latents=latents,
+        text=text,
+        mel=mel,
+        wav_length=wav_length,
    )

def head_(self, n):
@ -603,12 +608,14 @@ def create_train_val_dataloader():
    return train_dl, subtrain_dl, val_dl

def unpack_audio( npz ):
-    mel = torch.from_numpy(npz["codes"].astype(int))[0].t().to(dtype=torch.int16)
+    mel = npz["codes"].to(dtype=torch.int16, device="cpu")
+    conds = npz["conds"][0].to(dtype=torch.int16, device="cpu")
+    latent = npz["latent"][0].to(dtype=torch.int16, device="cpu")

    metadata = {}

-    if "text" in npz["metadata"]:
-        metadata["text"] = npz["metadata"]["text"]
+    if "text" in npz:
+        metadata["text"] = npz["text"]

    if "phonemes" in npz["metadata"]:
        metadata["phonemes"] = npz["metadata"]["phonemes"]
@ -616,10 +623,15 @@ def unpack_audio( npz ):
if "language" in npz["metadata"]: if "language" in npz["metadata"]:
metadata["language"] = npz["metadata"]["language"] metadata["language"] = npz["metadata"]["language"]
if "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]: if "original_length" in npz["metadata"]:
metadata["wav_length"] = npz["metadata"]["original_length"]
if "duration" in npz["metadata"]:
metadata["duration"] = npz["metadata"]["duration"]
elif "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
metadata["duration"] = npz["metadata"]["original_length"] / npz["metadata"]["sample_rate"] metadata["duration"] = npz["metadata"]["original_length"] / npz["metadata"]["sample_rate"]
return mel, metadata return mel, conds, latent, metadata
# parse dataset into better to sample metadata # parse dataset into better to sample metadata
def create_dataset_metadata( skip_existing=True ): def create_dataset_metadata( skip_existing=True ):
@ -672,7 +684,7 @@ def create_dataset_metadata( skip_existing=True ):
        if audios:
            # ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
            npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
-            mel, utterance_metadata = unpack_audio( npz )
+            mel, conds, latents, utterance_metadata = unpack_audio( npz )

        # text
        if texts and text_exists and not utterance_metadata:
            utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
@ -755,11 +767,17 @@ def create_dataset_hdf5( skip_existing=True ):
        # audio
        if audios:
            npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
-            mel, utterance_metadata = unpack_audio( npz )
+            mel, conds, latents, utterance_metadata = unpack_audio( npz )

            if "audio" not in group:
                group.create_dataset('audio', data=mel.numpy().astype(np.int16), compression='lzf')
+            if "conds" not in group:
+                group.create_dataset('conds', data=conds.numpy().astype(np.int16), compression='lzf')
+            if "latents" not in group:
+                group.create_dataset('latents', data=latents.numpy().astype(np.int16), compression='lzf')

        # text
        if texts:
            if not utterance_metadata and text_exists:
@ -778,6 +796,7 @@ def create_dataset_hdf5( skip_existing=True ):
        except Exception as e:
            tqdm.write(f'Error while processing {id}: {e}')
+            raise e

    with open(str(metadata_path), "w", encoding="utf-8") as f:
        f.write( json.dumps( metadata ) )
@ -837,28 +856,10 @@ if __name__ == "__main__":
    samples = {
        "training": [ next(iter(train_dl)), next(iter(train_dl)) ],
-        "evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
-        "validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+        #"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+        #"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
    }

-    Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
-
-    for k, v in samples.items():
-        for i in range(len(v)):
-            for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
-                """
-                try:
-                    decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
-                except Exception as e:
-                    print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
-                try:
-                    decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
-                except Exception as e:
-                    print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
-                """
-                v[i]['proms'][j] = v[i]['proms'][j].shape
-                v[i]['resps'][j] = v[i]['resps'][j].shape

    for k, v in samples.items():
        for i in range(len(v)):
            print(f'{k}[{i}]:', v[i])
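
Putting the create_dataset_hdf5 changes together, each utterance group in the training HDF5 is expected to carry the mel tokens plus the new conditioning datasets. A rough h5py sketch; the file path and group key are made up for illustration:

import h5py

with h5py.File("./training/dataset.h5", "r") as hf:
    group = hf["/LibriTTS-Train/1034/utterance-0001"]
    text       = group["text"][:]            # tokenized text ids
    mel_codes  = group["audio"][:]           # int16 DVAE mel tokens
    conds      = group["conds"][:]           # conditioning mel clips (new in this commit)
    latents    = group["latents"][:]         # conditioning latents (new in this commit)
    wav_length = group.attrs["wav_length"]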

View File

@ -15,6 +15,19 @@ from tqdm import tqdm
from ..models import load_model, unload_model

+import torch.nn.functional as F
+
+def pad_or_truncate(t, length):
+    """
+    Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
+    """
+    if t.shape[-1] == length:
+        return t
+    elif t.shape[-1] < length:
+        return F.pad(t, (0, length-t.shape[-1]))
+    else:
+        return t[..., :length]

# decodes mel spectrogram into a wav
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
@ -34,30 +47,30 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

-def format_autoregressive_conditioning( wav, cond_length=132300, device ):
+def format_autoregressive_conditioning( wav, cond_length=132300, device="cuda" ):
    """
    Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
    """
    model = load_model("tms", device=device)

-    gap = wav.shape[-1] - cond_length
-    if gap < 0:
-        wav = F.pad(wav, pad=(0, abs(gap)))
-    elif gap > 0:
-        rand_start = random.randint(0, gap)
-        wav = wav[:, rand_start:rand_start + cond_length]
+    if cond_length > 0:
+        gap = wav.shape[-1] - cond_length
+        if gap < 0:
+            wav = F.pad(wav, pad=(0, abs(gap)))
+        elif gap > 0:
+            rand_start = random.randint(0, gap)
+            wav = wav[:, rand_start:rand_start + cond_length]

    mel_clip = model(wav.unsqueeze(0)).squeeze(0) # ???
    return mel_clip.unsqueeze(0).to(device) # ???

def format_diffusion_conditioning( sample, device, do_normalization=False ):
-    model = load_model("stft", device=device)
+    model = load_model("stft", device=device, sr=24_000)
    sample = torchaudio.functional.resample(sample, 22050, 24000)
    sample = pad_or_truncate(sample, 102400)
    sample = sample.to(device)
-    mel = model.mel_spectrogram(wav)
+    mel = model.mel_spectrogram(sample)
    """
    if do_normalization:
        mel = normalize_tacotron_mel(mel)
@ -70,6 +83,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
dvae = load_model("dvae", device=device) dvae = load_model("dvae", device=device)
unified_voice = load_model("unified_voice", device=device) unified_voice = load_model("unified_voice", device=device)
diffusion = load_model("diffusion", device=device) diffusion = load_model("diffusion", device=device)
mel_inputs = format_autoregressive_conditioning( wav, 0, device )
wav_length = wav.shape[-1] wav_length = wav.shape[-1]
duration = wav_length / sr duration = wav_length / sr
@ -78,16 +92,19 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
    diffusion_conds = torch.stack([ format_diffusion_conditioning(wav.to(device), device=device) ], dim=1)

    codes = dvae.get_codebook_indices( mel_inputs )
-    auto_latent = unified_voice.get_conditioning(autoregressive_conds)
+    autoregressive_latent = unified_voice.get_conditioning(autoregressive_conds)
    diffusion_latent = diffusion.get_conditioning(diffusion_conds)

    return {
        "codes": codes,
        "conds": (autoregressive_conds, diffusion_conds),
        "latent": (autoregressive_latent, diffusion_latent),
-        "original_length": wav_length,
-        "sample_rate": sr,
-        "duration": duration
+        "metadata": {
+            "original_length": wav_length,
+            "sample_rate": sr,
+            "duration": duration
+        }
    }

def encode_from_files(paths, device="cuda"):

View File

@ -3,7 +3,7 @@
from functools import cache

-from ..arch_utils import TorchMelSpectrogram, TacotronSTFT
+from .arch_utils import TorchMelSpectrogram, TacotronSTFT
from .unified_voice import UnifiedVoice
from .diffusion import DiffusionTTS
@ -11,26 +11,45 @@ from .vocoder import UnivNetGenerator
from .clvp import CLVP
from .dvae import DiscreteVAE

+import os
+import torch
+
+DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/')

# semi-necessary as a way to provide a mechanism for other portions of the program to access models
@cache
def load_model(name, device="cuda", **kwargs):
+    load_path = None
    if "autoregressive" in name or "unified_voice" in name:
        model = UnifiedVoice(**kwargs)
+        load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
    elif "diffusion" in name:
        model = DiffusionTTS(**kwargs)
+        load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
    elif "clvp" in name:
        model = CLVP(**kwargs)
+        load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
    elif "vocoder" in name:
        model = UnivNetGenerator(**kwargs)
+        load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
    elif "dvae" in name:
+        load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
        model = DiscreteVAE(**kwargs)
-    # to-do: figure out of the below two give the exact same output, since the AR uses #1, the Diffusion uses #2
+    # to-do: figure out of the below two give the exact same output
    elif "stft" in name:
-        model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
+        sr = kwargs.pop("sr")
+        if sr == 24_000:
+            model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
+        else:
+            model = TacotronSTFT(**kwargs)
    elif "tms" in name:
        model = TorchMelSpectrogram(**kwargs)

    model = model.to(device=device)

+    if load_path is not None:
+        model.load_state_dict(torch.load(load_path, map_location=device), strict=False)

    return model

def unload_model():
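
Because load_model is wrapped in functools.cache, repeated calls with the same arguments hand back the same already-loaded module, and the new load_path logic pulls the matching weights from the data/ directory automatically. A small usage sketch (the package path is assumed from the diff's layout):

from tortoise_tts.models import load_model, unload_model

dvae_a = load_model("dvae", device="cuda")
dvae_b = load_model("dvae", device="cuda")
assert dvae_a is dvae_b            # cached: the checkpoint is only read once

# the STFT model now expects a sample-rate hint, per the change above
stft = load_model("stft", device="cuda", sr=24_000)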

View File

@ -289,7 +289,7 @@ class AudioMiniEncoder(nn.Module):
        return h[:, :, 0]

-DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
+DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/mel_norms.pth')

class TorchMelSpectrogram(nn.Module):
@ -463,6 +463,34 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x

+TACOTRON_MEL_MAX = 2.3143386840820312
+TACOTRON_MEL_MIN = -11.512925148010254
+
+def denormalize_tacotron_mel(norm_mel):
+    return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
+
+def normalize_tacotron_mel(mel):
+    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
@ -566,9 +594,16 @@ class STFT(torch.nn.Module):
        return reconstruction

class TacotronSTFT(torch.nn.Module):
-    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
-                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
-                 mel_fmax=8000.0):
+    def __init__(
+        self,
+        filter_length=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=80,
+        sampling_rate=22050,
+        mel_fmin=0.0,
+        mel_fmax=8000.0
+    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
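
The Tacotron mel normalization helpers added above map the fixed [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] range onto [-1, 1] and back. A quick sanity check of the round trip (sketch):

import torch

x = torch.linspace(TACOTRON_MEL_MIN, TACOTRON_MEL_MAX, steps=8)
assert torch.allclose(denormalize_tacotron_mel(normalize_tacotron_mel(x)), x)
assert normalize_tacotron_mel(torch.tensor(TACOTRON_MEL_MIN)) == -1.0
assert normalize_tacotron_mel(torch.tensor(TACOTRON_MEL_MAX)) == 1.0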

View File

@ -119,7 +119,7 @@ class DiscreteVAE(nn.Module):
        positional_dims = 1, # 2
        num_tokens = 8192, # 512
        codebook_dim = 512,
-        num_layers = 3,
+        num_layers = 2, # 3
        num_resnet_blocks = 3, # 0
        hidden_dim = 512, # 64
        channels = 80, # 3

View File

@ -2,10 +2,9 @@
from .config import cfg
from .data import create_train_val_dataloader
-from .emb import qnt
+from .emb import mel
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
-from .data import fold_inputs, unfold_outputs
from .utils.distributed import is_global_leader

import auraloss
@ -34,7 +33,7 @@ def train_feeder(engine, batch):
    text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
    text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True)
    mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True)
-    wav_lengths = pad_sequence([ length for length in batch["wav_lengths"] ], batch_first = True)
+    wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True)

    engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)
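
For context, train_feeder pads the per-sample fields produced by the updated Dataset.__getitem__ earlier in this commit (and conditioning_latents presumably comes from the "latents" field). A rough sketch of the dict each sample contributes; the concrete tensor values are placeholders:

sample = {
    "index": index,
    "path": Path(path),
    "spkr_name": spkr_name,
    "spkr_id": spkr_id,
    "latents": latents,        # conditioning latent
    "text": text,              # tokenized text ids
    "mel": mel,                # int16 DVAE mel tokens
    "wav_length": wav_length,
}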

View File

@ -1 +1 @@
-__version__ = "0.0.1-dev20240602082927"
+__version__ = "0.0.1-dev20240617224834"