encoding mel tokens + dataset preparation
This commit is contained in:
parent
37ec9f1b79
commit
d7b63d2f70
BIN
data/mel_norms.pth
Normal file
BIN
data/mel_norms.pth
Normal file
Binary file not shown.
1
data/tokenizer.json
Normal file
1
data/tokenizer.json
Normal file
|
@ -0,0 +1 @@
|
|||
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
|
|
@ -1,96 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
from vall_e.emb.g2p import encode as valle_phonemize
|
||||
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
|
||||
|
||||
device = "cuda"
|
||||
|
||||
target = "in"
|
||||
|
||||
audio_map = {}
|
||||
text_map = {}
|
||||
|
||||
data = {}
|
||||
|
||||
for season in os.listdir(f"./{target}/"):
|
||||
if not os.path.isdir(f"./{target}/{season}/"):
|
||||
continue
|
||||
|
||||
for episode in os.listdir(f"./{target}/{season}/"):
|
||||
if not os.path.isdir(f"./{target}/{season}/{episode}/"):
|
||||
continue
|
||||
|
||||
for filename in os.listdir(f"./{target}/{season}/{episode}/"):
|
||||
path = f'./{target}/{season}/{episode}/{filename}'
|
||||
attrs = filename.split("_")
|
||||
timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
|
||||
|
||||
key = f'{episode}_{timestamp}'
|
||||
|
||||
if filename[-5:] == ".flac":
|
||||
name = attrs[3]
|
||||
emotion = attrs[4]
|
||||
quality = attrs[5]
|
||||
|
||||
audio_map[key] = {
|
||||
"path": path,
|
||||
'episode': episode,
|
||||
"name": name,
|
||||
"emotion": emotion,
|
||||
"quality": quality,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
|
||||
elif filename[-4:] == ".txt":
|
||||
text_map[key] = open(path, encoding="utf-8").read()
|
||||
txts = {}
|
||||
wavs = []
|
||||
|
||||
for key, entry in audio_map.items():
|
||||
path = entry['path']
|
||||
name = entry['name']
|
||||
emotion = entry['emotion']
|
||||
quality = entry['quality']
|
||||
episode = entry['episode']
|
||||
path = entry['path']
|
||||
timestamp = entry['timestamp']
|
||||
transcription = text_map[key]
|
||||
if name not in data:
|
||||
data[name] = {}
|
||||
os.makedirs(f'./training/{name}/', exist_ok=True)
|
||||
os.makedirs(f'./voices/{name}/', exist_ok=True)
|
||||
|
||||
key = f'{episode}_{timestamp}.flac'
|
||||
os.rename(path, f'./voices/{name}/{key}')
|
||||
|
||||
data[name][key] = {
|
||||
"segments": [],
|
||||
"language": "en",
|
||||
"text": transcription,
|
||||
"misc": {
|
||||
"emotion": emotion,
|
||||
"quality": quality,
|
||||
"timestamp": timestamp,
|
||||
"episode": episode,
|
||||
}
|
||||
}
|
||||
|
||||
path = f'./voices/{name}/{key}'
|
||||
txts[path] = transcription
|
||||
wavs.append(Path(path))
|
||||
|
||||
for name in data.keys():
|
||||
open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
|
||||
|
||||
for key, text in tqdm(txts.items(), desc="Phonemizing..."):
|
||||
path = Path(key)
|
||||
phones = valle_phonemize(text)
|
||||
open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
|
||||
|
||||
for path in tqdm(wavs, desc="Quantizing..."):
|
||||
qnt = valle_quantize(path, device=device)
|
||||
torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
|
|
@ -5,7 +5,7 @@ import torchaudio
|
|||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from pathlib import Path
|
||||
from vall_e.config import cfg
|
||||
from tortoise_tts.config import cfg
|
||||
|
||||
# things that could be args
|
||||
cfg.sample_rate = 24_000
|
||||
|
@ -16,15 +16,14 @@ cfg.inference.dtype = torch.bfloat16
|
|||
cfg.inference.amp = True
|
||||
"""
|
||||
|
||||
from vall_e.emb.g2p import encode as valle_phonemize
|
||||
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
|
||||
from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension
|
||||
|
||||
input_audio = "voices"
|
||||
input_metadata = "metadata"
|
||||
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
|
||||
device = "cuda"
|
||||
|
||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
||||
audio_extension = ".mel"
|
||||
|
||||
slice = "auto"
|
||||
missing = {
|
||||
|
@ -57,30 +56,8 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
continue
|
||||
|
||||
waveform, sample_rate = torchaudio.load(inpath)
|
||||
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
"dac_version": "1.0.0",
|
||||
},
|
||||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
"sample_rate": sample_rate,
|
||||
},
|
||||
})
|
||||
mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel )
|
||||
|
||||
continue
|
||||
|
||||
|
@ -177,45 +154,16 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
|
|||
))
|
||||
|
||||
if len(wavs) > 0:
|
||||
for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
|
||||
for job in tqdm(wavs, desc=f"Encoding: {speaker_id}"):
|
||||
try:
|
||||
outpath, text, language, waveform, sample_rate = job
|
||||
|
||||
phones = valle_phonemize(text)
|
||||
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
|
||||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
"dac_version": "1.0.0",
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": waveform.shape[-1],
|
||||
"sample_rate": sample_rate,
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": language,
|
||||
},
|
||||
})
|
||||
mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
|
||||
mel["text"] = text
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
|
||||
except Exception as e:
|
||||
print(f"Failed to quantize: {outpath}:", e)
|
||||
print(f"Failed to encode: {outpath}:", e)
|
||||
continue
|
||||
|
||||
open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
|
||||
|
|
|
@ -18,9 +18,9 @@ cfg.inference.amp = True
|
|||
"""
|
||||
|
||||
from vall_e.emb.g2p import encode as valle_phonemize
|
||||
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
|
||||
from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension
|
||||
|
||||
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
|
||||
audio_extension = ".mel"
|
||||
|
||||
input_dataset = "LibriTTS_R"
|
||||
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
|
||||
|
@ -56,52 +56,10 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
|
|||
for paths in tqdm(txts, desc="Processing..."):
|
||||
inpath, outpath = paths
|
||||
try:
|
||||
if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
|
||||
data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
|
||||
qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
|
||||
|
||||
if not isinstance(data["phonemes"], str):
|
||||
data["phonemes"] = "".join(data["phonemes"])
|
||||
|
||||
for k, v in data.items():
|
||||
qnt[()]['metadata'][k] = v
|
||||
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
|
||||
else:
|
||||
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
|
||||
|
||||
phones = valle_phonemize(text)
|
||||
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
|
||||
|
||||
if cfg.inference.audio_backend == "dac":
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.original_length,
|
||||
"sample_rate": qnt.sample_rate,
|
||||
|
||||
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
|
||||
"chunk_length": qnt.chunk_length,
|
||||
"channels": qnt.channels,
|
||||
"padding": qnt.padding,
|
||||
"dac_version": "1.0.0",
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": "en",
|
||||
},
|
||||
})
|
||||
else:
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
|
||||
"codes": qnt.cpu().numpy().astype(np.uint16),
|
||||
"metadata": {
|
||||
"original_length": qnt.shape[-1] / 75.0,
|
||||
"sample_rate": cfg.sample_rate,
|
||||
|
||||
"text": text.strip(),
|
||||
"phonemes": "".join(phones),
|
||||
"language": "en",
|
||||
},
|
||||
})
|
||||
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
|
||||
|
||||
mel = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
|
||||
mel["text"] = text
|
||||
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
|
||||
except Exception as e:
|
||||
tqdm.write(f"Failed to process: {paths}: {e}")
|
||||
|
|
3
setup.py
3
setup.py
|
@ -54,6 +54,9 @@ setup(
|
|||
"tokenizers",
|
||||
"transformers",
|
||||
|
||||
#
|
||||
"rotary_embedding_torch",
|
||||
|
||||
# training bloat
|
||||
"auraloss[all]", # [all] is needed for MelSTFTLoss
|
||||
"h5py",
|
||||
|
|
|
@ -502,7 +502,7 @@ class Optimizations:
|
|||
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
|
||||
|
||||
bitsandbytes: bool = False # use bitsandbytes
|
||||
dadaptation: bool = True # use dadaptation optimizer
|
||||
dadaptation: bool = False # use dadaptation optimizer
|
||||
bitnet: bool = False # use bitnet
|
||||
fp8: bool = False # use fp8
|
||||
|
||||
|
@ -525,6 +525,7 @@ class Config(BaseConfig):
|
|||
tokenizer: str = "./tokenizer.json"
|
||||
|
||||
sample_rate: int = 24_000
|
||||
audio_backend: str = "mel"
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
|
|
|
@ -169,8 +169,10 @@ def _get_paths_of_extensions( path, extensions=_get_mel_extension(), validate=Fa
|
|||
def _load_mels(path, return_metadata=False) -> Tensor:
|
||||
mel = np.load(_get_mel_path(path), allow_pickle=True)[()]
|
||||
if return_metadata:
|
||||
return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16), mel["metadata"]
|
||||
return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16)
|
||||
mel["metadata"]["text"] = mel["text"]
|
||||
|
||||
return mel["codes"].to(torch.int16), mel["metadata"]
|
||||
return mel["codes"].to(torch.int16)
|
||||
|
||||
# prune consecutive spaces
|
||||
def _cleanup_phones( phones, targets=[" "]):
|
||||
|
@ -453,37 +455,19 @@ class Dataset(_Dataset):
|
|||
)
|
||||
"""
|
||||
|
||||
prom_length = 0
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||
path = random.choice(choices)
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
if "audio" not in cfg.hdf5[key]:
|
||||
_logger.warning(f'MISSING AUDIO: {key}')
|
||||
return
|
||||
|
||||
if "audio" not in cfg.hdf5[key]:
|
||||
_logger.warning(f'MISSING AUDIO: {key}')
|
||||
continue
|
||||
|
||||
mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
|
||||
else:
|
||||
mel = _load_mels(path, return_metadata=False)
|
||||
|
||||
if 0 < trim_length and trim_length < mel.shape[0]:
|
||||
mel = trim( mel, trim_length )
|
||||
|
||||
prom_list.append(mel)
|
||||
prom_length += mel.shape[0]
|
||||
|
||||
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
|
||||
break
|
||||
|
||||
# might be better to decode => concat waveforms with silence in between => reencode
|
||||
# as you technically can't just append encodec sequences together like this without issues
|
||||
prom = torch.cat(prom_list)
|
||||
|
||||
if 0 < trim_length and trim_length < prom.shape[0]:
|
||||
prom = trim( prom, trim_length )
|
||||
# audio / cond / latents
|
||||
# parameter names and documentation are weird
|
||||
prom = torch.from_numpy(cfg.hdf5[key]["cond"]).to(torch.int16)
|
||||
else:
|
||||
prom = _load_mels(path, return_metadata=False)
|
||||
|
||||
return prom
|
||||
|
||||
|
@ -507,15 +491,36 @@ class Dataset(_Dataset):
|
|||
spkr_group = self.get_speaker_group(path)
|
||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||
|
||||
if cfg.dataset.use_hdf5:
|
||||
key = _get_hdf5_path(path)
|
||||
|
||||
if key not in cfg.hdf5:
|
||||
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||
|
||||
text = cfg.hdf5[key]["text"][:]
|
||||
mel = cfg.hdf5[key]["audio"][:]
|
||||
latents = cfg.hdf5[key]["latents"][:]
|
||||
|
||||
text = torch.from_numpy(text).to(self.text_dtype)
|
||||
mel = torch.from_numpy(mel).to(torch.int16)
|
||||
latents = torch.from_numpy(latents)
|
||||
wav_length = cfg.hdf5[key].attrs["wav_length"]
|
||||
else:
|
||||
mel, metadata = _load_mels(path, return_metadata=True)
|
||||
text = torch.tensor(metadata["text"]).to(self.text_dtype)
|
||||
latents = torch.from_numpy(metadata["latent"][0])
|
||||
wav_length = metadata["wav_length"]
|
||||
|
||||
return dict(
|
||||
index=index,
|
||||
path=Path(path),
|
||||
spkr_name=spkr_name,
|
||||
spkr_id=spkr_id,
|
||||
#text=text,
|
||||
#proms=proms,
|
||||
#resps=resps,
|
||||
|
||||
latents=latents,
|
||||
text=text,
|
||||
mel=mel,
|
||||
wav_length=wav_length,
|
||||
)
|
||||
|
||||
def head_(self, n):
|
||||
|
@ -603,12 +608,14 @@ def create_train_val_dataloader():
|
|||
return train_dl, subtrain_dl, val_dl
|
||||
|
||||
def unpack_audio( npz ):
|
||||
mel = torch.from_numpy(npz["codes"].astype(int))[0].t().to(dtype=torch.int16)
|
||||
mel = npz["codes"].to(dtype=torch.int16, device="cpu")
|
||||
conds = npz["conds"][0].to(dtype=torch.int16, device="cpu")
|
||||
latent = npz["latent"][0].to(dtype=torch.int16, device="cpu")
|
||||
|
||||
metadata = {}
|
||||
|
||||
if "text" in npz["metadata"]:
|
||||
metadata["text"] = npz["metadata"]["text"]
|
||||
if "text" in npz:
|
||||
metadata["text"] = npz["text"]
|
||||
|
||||
if "phonemes" in npz["metadata"]:
|
||||
metadata["phonemes"] = npz["metadata"]["phonemes"]
|
||||
|
@ -616,10 +623,15 @@ def unpack_audio( npz ):
|
|||
if "language" in npz["metadata"]:
|
||||
metadata["language"] = npz["metadata"]["language"]
|
||||
|
||||
if "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
|
||||
if "original_length" in npz["metadata"]:
|
||||
metadata["wav_length"] = npz["metadata"]["original_length"]
|
||||
|
||||
if "duration" in npz["metadata"]:
|
||||
metadata["duration"] = npz["metadata"]["duration"]
|
||||
elif "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
|
||||
metadata["duration"] = npz["metadata"]["original_length"] / npz["metadata"]["sample_rate"]
|
||||
|
||||
return mel, metadata
|
||||
return mel, conds, latent, metadata
|
||||
|
||||
# parse dataset into better to sample metadata
|
||||
def create_dataset_metadata( skip_existing=True ):
|
||||
|
@ -672,7 +684,7 @@ def create_dataset_metadata( skip_existing=True ):
|
|||
if audios:
|
||||
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
|
||||
npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
|
||||
mel, utterance_metadata = unpack_audio( npz )
|
||||
mel, conds, latents, utterance_metadata = unpack_audio( npz )
|
||||
# text
|
||||
if texts and text_exists and not utterance_metadata:
|
||||
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
|
||||
|
@ -755,10 +767,16 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
# audio
|
||||
if audios:
|
||||
npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
|
||||
mel, utterance_metadata = unpack_audio( npz )
|
||||
mel, conds, latents, utterance_metadata = unpack_audio( npz )
|
||||
|
||||
if "audio" not in group:
|
||||
group.create_dataset('audio', data=mel.numpy().astype(np.int16), compression='lzf')
|
||||
|
||||
if "conds" not in group:
|
||||
group.create_dataset('conds', data=conds.numpy().astype(np.int16), compression='lzf')
|
||||
|
||||
if "latents" not in group:
|
||||
group.create_dataset('latents', data=latents.numpy().astype(np.int16), compression='lzf')
|
||||
|
||||
# text
|
||||
if texts:
|
||||
|
@ -778,6 +796,7 @@ def create_dataset_hdf5( skip_existing=True ):
|
|||
|
||||
except Exception as e:
|
||||
tqdm.write(f'Error while processing {id}: {e}')
|
||||
raise e
|
||||
|
||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||
f.write( json.dumps( metadata ) )
|
||||
|
@ -837,27 +856,9 @@ if __name__ == "__main__":
|
|||
|
||||
samples = {
|
||||
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
|
||||
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
|
||||
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
|
||||
}
|
||||
|
||||
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
|
||||
"""
|
||||
try:
|
||||
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
|
||||
try:
|
||||
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
|
||||
except Exception as e:
|
||||
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
|
||||
"""
|
||||
v[i]['proms'][j] = v[i]['proms'][j].shape
|
||||
v[i]['resps'][j] = v[i]['resps'][j].shape
|
||||
|
||||
for k, v in samples.items():
|
||||
for i in range(len(v)):
|
||||
|
|
|
@ -15,6 +15,19 @@ from tqdm import tqdm
|
|||
|
||||
from ..models import load_model, unload_model
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
def pad_or_truncate(t, length):
|
||||
"""
|
||||
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
|
||||
"""
|
||||
if t.shape[-1] == length:
|
||||
return t
|
||||
elif t.shape[-1] < length:
|
||||
return F.pad(t, (0, length-t.shape[-1]))
|
||||
else:
|
||||
return t[..., :length]
|
||||
|
||||
# decodes mel spectrogram into a wav
|
||||
@torch.inference_mode()
|
||||
def decode(codes: Tensor, device="cuda"):
|
||||
|
@ -34,30 +47,30 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
|
|||
def _replace_file_extension(path, suffix):
|
||||
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
|
||||
|
||||
def format_autoregressive_conditioning( wav, cond_length=132300, device ):
|
||||
def format_autoregressive_conditioning( wav, cond_length=132300, device="cuda" ):
|
||||
"""
|
||||
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
|
||||
"""
|
||||
model = load_model("tms", device=device)
|
||||
|
||||
gap = wav.shape[-1] - cond_length
|
||||
|
||||
if gap < 0:
|
||||
wav = F.pad(wav, pad=(0, abs(gap)))
|
||||
elif gap > 0:
|
||||
rand_start = random.randint(0, gap)
|
||||
wav = wav[:, rand_start:rand_start + cond_length]
|
||||
if cond_length > 0:
|
||||
gap = wav.shape[-1] - cond_length
|
||||
if gap < 0:
|
||||
wav = F.pad(wav, pad=(0, abs(gap)))
|
||||
elif gap > 0:
|
||||
rand_start = random.randint(0, gap)
|
||||
wav = wav[:, rand_start:rand_start + cond_length]
|
||||
|
||||
mel_clip = model(wav.unsqueeze(0)).squeeze(0) # ???
|
||||
return mel_clip.unsqueeze(0).to(device) # ???
|
||||
|
||||
def format_diffusion_conditioning( sample, device, do_normalization=False ):
|
||||
model = load_model("stft", device=device)
|
||||
model = load_model("stft", device=device, sr=24_000)
|
||||
|
||||
sample = torchaudio.functional.resample(sample, 22050, 24000)
|
||||
sample = pad_or_truncate(sample, 102400)
|
||||
sample = sample.to(device)
|
||||
mel = model.mel_spectrogram(wav)
|
||||
mel = model.mel_spectrogram(sample)
|
||||
"""
|
||||
if do_normalization:
|
||||
mel = normalize_tacotron_mel(mel)
|
||||
|
@ -70,6 +83,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
|
|||
dvae = load_model("dvae", device=device)
|
||||
unified_voice = load_model("unified_voice", device=device)
|
||||
diffusion = load_model("diffusion", device=device)
|
||||
mel_inputs = format_autoregressive_conditioning( wav, 0, device )
|
||||
|
||||
wav_length = wav.shape[-1]
|
||||
duration = wav_length / sr
|
||||
|
@ -78,16 +92,19 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
|
|||
diffusion_conds = torch.stack([ format_diffusion_conditioning(wav.to(device), device=device) ], dim=1)
|
||||
|
||||
codes = dvae.get_codebook_indices( mel_inputs )
|
||||
auto_latent = unified_voice.get_conditioning(autoregressive_conds)
|
||||
|
||||
autoregressive_latent = unified_voice.get_conditioning(autoregressive_conds)
|
||||
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
|
||||
|
||||
return {
|
||||
"codes": codes,
|
||||
"conds": (autoregressive_conds, diffusion_conds),
|
||||
"latent": (autoregressive_latent, diffusion_latent),
|
||||
"original_length": wav_length,
|
||||
"sample_rate": sr,
|
||||
"duration": duration
|
||||
"metadata": {
|
||||
"original_length": wav_length,
|
||||
"sample_rate": sr,
|
||||
"duration": duration
|
||||
}
|
||||
}
|
||||
|
||||
def encode_from_files(paths, device="cuda"):
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
from functools import cache
|
||||
|
||||
from ..arch_utils import TorchMelSpectrogram, TacotronSTFT
|
||||
from .arch_utils import TorchMelSpectrogram, TacotronSTFT
|
||||
|
||||
from .unified_voice import UnifiedVoice
|
||||
from .diffusion import DiffusionTTS
|
||||
|
@ -11,26 +11,45 @@ from .vocoder import UnivNetGenerator
|
|||
from .clvp import CLVP
|
||||
from .dvae import DiscreteVAE
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/')
|
||||
|
||||
# semi-necessary as a way to provide a mechanism for other portions of the program to access models
|
||||
@cache
|
||||
def load_model(name, device="cuda", **kwargs):
|
||||
load_path = None
|
||||
if "autoregressive" in name or "unified_voice" in name:
|
||||
model = UnifiedVoice(**kwargs)
|
||||
load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
|
||||
elif "diffusion" in name:
|
||||
model = DiffusionTTS(**kwargs)
|
||||
load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
|
||||
elif "clvp" in name:
|
||||
model = CLVP(**kwargs)
|
||||
load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
|
||||
elif "vocoder" in name:
|
||||
model = UnivNetGenerator(**kwargs)
|
||||
load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
|
||||
elif "dvae" in name:
|
||||
load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
|
||||
model = DiscreteVAE(**kwargs)
|
||||
# to-do: figure out of the below two give the exact same output, since the AR uses #1, the Diffusion uses #2
|
||||
# to-do: figure out of the below two give the exact same output
|
||||
elif "stft" in name:
|
||||
model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
|
||||
sr = kwargs.pop("sr")
|
||||
if sr == 24_000:
|
||||
model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
|
||||
else:
|
||||
model = TacotronSTFT(**kwargs)
|
||||
elif "tms" in name:
|
||||
model = TorchMelSpectrogram(**kwargs)
|
||||
|
||||
|
||||
model = model.to(device=device)
|
||||
|
||||
if load_path is not None:
|
||||
model.load_state_dict(torch.load(load_path, map_location=device), strict=False)
|
||||
|
||||
return model
|
||||
|
||||
def unload_model():
|
||||
|
|
|
@ -289,7 +289,7 @@ class AudioMiniEncoder(nn.Module):
|
|||
return h[:, :, 0]
|
||||
|
||||
|
||||
DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
|
||||
DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/mel_norms.pth')
|
||||
|
||||
|
||||
class TorchMelSpectrogram(nn.Module):
|
||||
|
@ -463,6 +463,34 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
|
|||
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||
return x
|
||||
|
||||
TACOTRON_MEL_MAX = 2.3143386840820312
|
||||
TACOTRON_MEL_MIN = -11.512925148010254
|
||||
|
||||
|
||||
def denormalize_tacotron_mel(norm_mel):
|
||||
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
|
||||
|
||||
|
||||
def normalize_tacotron_mel(mel):
|
||||
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
||||
|
||||
|
||||
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor
|
||||
"""
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression(x, C=1):
|
||||
"""
|
||||
PARAMS
|
||||
------
|
||||
C: compression factor used to compress
|
||||
"""
|
||||
return torch.exp(x) / C
|
||||
|
||||
class STFT(torch.nn.Module):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
|
@ -566,9 +594,16 @@ class STFT(torch.nn.Module):
|
|||
return reconstruction
|
||||
|
||||
class TacotronSTFT(torch.nn.Module):
|
||||
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
|
||||
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
|
||||
mel_fmax=8000.0):
|
||||
def __init__(
|
||||
self,
|
||||
filter_length=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mel_channels=80,
|
||||
sampling_rate=22050,
|
||||
mel_fmin=0.0,
|
||||
mel_fmax=8000.0
|
||||
):
|
||||
super().__init__()
|
||||
self.n_mel_channels = n_mel_channels
|
||||
self.sampling_rate = sampling_rate
|
||||
|
|
|
@ -119,7 +119,7 @@ class DiscreteVAE(nn.Module):
|
|||
positional_dims = 1, # 2
|
||||
num_tokens = 8192, # 512
|
||||
codebook_dim = 512,
|
||||
num_layers = 3,
|
||||
num_layers = 2, # 3
|
||||
num_resnet_blocks = 3, # 0
|
||||
hidden_dim = 512, # 64
|
||||
channels = 80, # 3
|
||||
|
|
|
@ -2,10 +2,9 @@
|
|||
|
||||
from .config import cfg
|
||||
from .data import create_train_val_dataloader
|
||||
from .emb import qnt
|
||||
from .emb import mel
|
||||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .data import fold_inputs, unfold_outputs
|
||||
from .utils.distributed import is_global_leader
|
||||
|
||||
import auraloss
|
||||
|
@ -34,7 +33,7 @@ def train_feeder(engine, batch):
|
|||
text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
|
||||
text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True)
|
||||
mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True)
|
||||
wav_lengths = pad_sequence([ length for length in batch["wav_lengths"] ], batch_first = True)
|
||||
wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True)
|
||||
|
||||
|
||||
engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = "0.0.1-dev20240602082927"
|
||||
__version__ = "0.0.1-dev20240617224834"
|
||||
|
|
Loading…
Reference in New Issue
Block a user