encoding mel tokens + dataset preparation
parent 37ec9f1b79
commit d7b63d2f70

BIN  data/mel_norms.pth  (new file)
Binary file not shown.
data/tokenizer.json  (new file, +1 line)
@@ -0,0 +1 @@
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}
@@ -1,96 +0,0 @@
-import os
-import json
-import torch
-
-from tqdm.auto import tqdm
-from pathlib import Path
-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
-
-device = "cuda"
-
-target = "in"
-
-audio_map = {}
-text_map = {}
-
-data = {}
-
-for season in os.listdir(f"./{target}/"):
-	if not os.path.isdir(f"./{target}/{season}/"):
-		continue
-
-	for episode in os.listdir(f"./{target}/{season}/"):
-		if not os.path.isdir(f"./{target}/{season}/{episode}/"):
-			continue
-
-		for filename in os.listdir(f"./{target}/{season}/{episode}/"):
-			path = f'./{target}/{season}/{episode}/{filename}'
-			attrs = filename.split("_")
-			timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
-
-			key = f'{episode}_{timestamp}'
-
-			if filename[-5:] == ".flac":
-				name = attrs[3]
-				emotion = attrs[4]
-				quality = attrs[5]
-
-				audio_map[key] = {
-					"path": path,
-					'episode': episode,
-					"name": name,
-					"emotion": emotion,
-					"quality": quality,
-					"timestamp": timestamp,
-				}
-
-			elif filename[-4:] == ".txt":
-				text_map[key] = open(path, encoding="utf-8").read()
-
-txts = {}
-wavs = []
-
-for key, entry in audio_map.items():
-	path = entry['path']
-	name = entry['name']
-	emotion = entry['emotion']
-	quality = entry['quality']
-	episode = entry['episode']
-	path = entry['path']
-	timestamp = entry['timestamp']
-	transcription = text_map[key]
-	if name not in data:
-		data[name] = {}
-		os.makedirs(f'./training/{name}/', exist_ok=True)
-		os.makedirs(f'./voices/{name}/', exist_ok=True)
-
-	key = f'{episode}_{timestamp}.flac'
-	os.rename(path, f'./voices/{name}/{key}')
-
-	data[name][key] = {
-		"segments": [],
-		"language": "en",
-		"text": transcription,
-		"misc": {
-			"emotion": emotion,
-			"quality": quality,
-			"timestamp": timestamp,
-			"episode": episode,
-		}
-	}
-
-	path = f'./voices/{name}/{key}'
-	txts[path] = transcription
-	wavs.append(Path(path))
-
-for name in data.keys():
-	open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
-
-for key, text in tqdm(txts.items(), desc="Phonemizing..."):
-	path = Path(key)
-	phones = valle_phonemize(text)
-	open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
-
-for path in tqdm(wavs, desc="Quantizing..."):
-	qnt = valle_quantize(path, device=device)
-	torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
@@ -5,7 +5,7 @@ import torchaudio
 import numpy as np
 from tqdm.auto import tqdm
 from pathlib import Path
-from vall_e.config import cfg
+from tortoise_tts.config import cfg

 # things that could be args
 cfg.sample_rate = 24_000
@@ -16,15 +16,14 @@ cfg.inference.dtype = torch.bfloat16
 cfg.inference.amp = True
 """

-from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
+from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension

 input_audio = "voices"
 input_metadata = "metadata"
 output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
 device = "cuda"

-audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+audio_extension = ".mel"

 slice = "auto"
 missing = {
@@ -57,30 +56,8 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 				continue

 			waveform, sample_rate = torchaudio.load(inpath)
-			qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-			if cfg.inference.audio_backend == "dac":
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": qnt.original_length,
-						"sample_rate": qnt.sample_rate,
-
-						"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-						"chunk_length": qnt.chunk_length,
-						"channels": qnt.channels,
-						"padding": qnt.padding,
-						"dac_version": "1.0.0",
-					},
-				})
-			else:
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": waveform.shape[-1],
-						"sample_rate": sample_rate,
-					},
-				})
+			mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
+			np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel )

 			continue
@@ -177,45 +154,16 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
 		))

 	if len(wavs) > 0:
-		for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
+		for job in tqdm(wavs, desc=f"Encoding: {speaker_id}"):
 			try:
 				outpath, text, language, waveform, sample_rate = job

 				phones = valle_phonemize(text)
-				qnt = valle_quantize(waveform, sr=sample_rate, device=device)
-
-				if cfg.inference.audio_backend == "dac":
-					np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-						"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-						"metadata": {
-							"original_length": qnt.original_length,
-							"sample_rate": qnt.sample_rate,
-
-							"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-							"chunk_length": qnt.chunk_length,
-							"channels": qnt.channels,
-							"padding": qnt.padding,
-							"dac_version": "1.0.0",
-
-							"text": text.strip(),
-							"phonemes": "".join(phones),
-							"language": language,
-						},
-					})
-				else:
-					np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-						"codes": qnt.cpu().numpy().astype(np.uint16),
-						"metadata": {
-							"original_length": waveform.shape[-1],
-							"sample_rate": sample_rate,
-
-							"text": text.strip(),
-							"phonemes": "".join(phones),
-							"language": language,
-						},
-					})
+				mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
+				mel["text"] = text
+				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
 			except Exception as e:
-				print(f"Failed to quantize: {outpath}:", e)
+				print(f"Failed to encode: {outpath}:", e)
 				continue

 open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
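The processing scripts now dump the dict returned by tortoise_tts.emb.mel.encode straight to a pickled .mel numpy file. A minimal sketch of reading one back, mirroring the np.load(..., allow_pickle=True)[()] pattern that _load_mels uses later in this commit; the path is illustrative, and "text" is only present when the prep script attached it:

	import numpy as np

	mel = np.load("./voices/speaker/example.mel", allow_pickle=True)[()]  # hypothetical path
	codes = mel["codes"]                    # dvae codebook indices
	ar_cond, diff_cond = mel["conds"]       # autoregressive / diffusion conditioning mels
	ar_latent, diff_latent = mel["latent"]  # conditioning latents from the AR and diffusion models
	print(mel.get("text"), mel["metadata"]["duration"])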
@@ -18,9 +18,9 @@ cfg.inference.amp = True
 """

 from vall_e.emb.g2p import encode as valle_phonemize
-from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
+from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension

-audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
+audio_extension = ".mel"

 input_dataset = "LibriTTS_R"
 output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
@@ -56,52 +56,10 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
 for paths in tqdm(txts, desc="Processing..."):
 	inpath, outpath = paths
 	try:
-		if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
-			data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
-			qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
-
-			if not isinstance(data["phonemes"], str):
-				data["phonemes"] = "".join(data["phonemes"])
-
-			for k, v in data.items():
-				qnt[()]['metadata'][k] = v
-
-			np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
-		else:
-			text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
-
-			phones = valle_phonemize(text)
-			qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
-
-			if cfg.inference.audio_backend == "dac":
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.codes.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": qnt.original_length,
-						"sample_rate": qnt.sample_rate,
-
-						"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
-						"chunk_length": qnt.chunk_length,
-						"channels": qnt.channels,
-						"padding": qnt.padding,
-						"dac_version": "1.0.0",
-
-						"text": text.strip(),
-						"phonemes": "".join(phones),
-						"language": "en",
-					},
-				})
-			else:
-				np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
-					"codes": qnt.cpu().numpy().astype(np.uint16),
-					"metadata": {
-						"original_length": qnt.shape[-1] / 75.0,
-						"sample_rate": cfg.sample_rate,
-
-						"text": text.strip(),
-						"phonemes": "".join(phones),
-						"language": "en",
-					},
-				})
+		text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
+
+		mel = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
+		mel["text"] = text
+		np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
 	except Exception as e:
 		tqdm.write(f"Failed to process: {paths}: {e}")
3  setup.py
@@ -54,6 +54,9 @@ setup(
 		"tokenizers",
 		"transformers",

+		#
+		"rotary_embedding_torch",
+
 		# training bloat
 		"auraloss[all]", # [all] is needed for MelSTFTLoss
 		"h5py",
@@ -502,7 +502,7 @@ class Optimizations:
 	optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)

 	bitsandbytes: bool = False # use bitsandbytes
-	dadaptation: bool = True # use dadaptation optimizer
+	dadaptation: bool = False # use dadaptation optimizer
 	bitnet: bool = False # use bitnet
 	fp8: bool = False # use fp8
|
@ -525,6 +525,7 @@ class Config(BaseConfig):
|
||||||
tokenizer: str = "./tokenizer.json"
|
tokenizer: str = "./tokenizer.json"
|
||||||
|
|
||||||
sample_rate: int = 24_000
|
sample_rate: int = 24_000
|
||||||
|
audio_backend: str = "mel"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
|
|
|
@@ -169,8 +169,10 @@ def _get_paths_of_extensions( path, extensions=_get_mel_extension(), validate=Fa
 def _load_mels(path, return_metadata=False) -> Tensor:
 	mel = np.load(_get_mel_path(path), allow_pickle=True)[()]
 	if return_metadata:
-		return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16), mel["metadata"]
-	return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16)
+		mel["metadata"]["text"] = mel["text"]
+		return mel["codes"].to(torch.int16), mel["metadata"]
+	return mel["codes"].to(torch.int16)

 # prune consecutive spaces
 def _cleanup_phones( phones, targets=[" "]):
@@ -453,37 +455,19 @@ class Dataset(_Dataset):
 		)
 		"""

-		prom_length = 0
-		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
-
-		for _ in range(cfg.dataset.max_prompts):
-			path = random.choice(choices)
-			if cfg.dataset.use_hdf5:
-				key = _get_hdf5_path(path)
-
-				if "audio" not in cfg.hdf5[key]:
-					_logger.warning(f'MISSING AUDIO: {key}')
-					continue
-
-				mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
-			else:
-				mel = _load_mels(path, return_metadata=False)
-
-			if 0 < trim_length and trim_length < mel.shape[0]:
-				mel = trim( mel, trim_length )
-
-			prom_list.append(mel)
-			prom_length += mel.shape[0]
-
-			if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
-				break
-
-		# might be better to decode => concat waveforms with silence in between => reencode
-		# as you technically can't just append encodec sequences together like this without issues
-		prom = torch.cat(prom_list)
-
-		if 0 < trim_length and trim_length < prom.shape[0]:
-			prom = trim( prom, trim_length )
+		path = random.choice(choices)
+		if cfg.dataset.use_hdf5:
+			key = _get_hdf5_path(path)
+			if "audio" not in cfg.hdf5[key]:
+				_logger.warning(f'MISSING AUDIO: {key}')
+				return
+
+			# audio / cond / latents
+			# parameter names and documentation are weird
+			prom = torch.from_numpy(cfg.hdf5[key]["cond"]).to(torch.int16)
+		else:
+			prom = _load_mels(path, return_metadata=False)

 		return prom
|
@ -507,15 +491,36 @@ class Dataset(_Dataset):
|
||||||
spkr_group = self.get_speaker_group(path)
|
spkr_group = self.get_speaker_group(path)
|
||||||
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
#spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||||
|
|
||||||
|
if cfg.dataset.use_hdf5:
|
||||||
|
key = _get_hdf5_path(path)
|
||||||
|
|
||||||
|
if key not in cfg.hdf5:
|
||||||
|
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
|
||||||
|
|
||||||
|
text = cfg.hdf5[key]["text"][:]
|
||||||
|
mel = cfg.hdf5[key]["audio"][:]
|
||||||
|
latents = cfg.hdf5[key]["latents"][:]
|
||||||
|
|
||||||
|
text = torch.from_numpy(text).to(self.text_dtype)
|
||||||
|
mel = torch.from_numpy(mel).to(torch.int16)
|
||||||
|
latents = torch.from_numpy(latents)
|
||||||
|
wav_length = cfg.hdf5[key].attrs["wav_length"]
|
||||||
|
else:
|
||||||
|
mel, metadata = _load_mels(path, return_metadata=True)
|
||||||
|
text = torch.tensor(metadata["text"]).to(self.text_dtype)
|
||||||
|
latents = torch.from_numpy(metadata["latent"][0])
|
||||||
|
wav_length = metadata["wav_length"]
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
index=index,
|
index=index,
|
||||||
path=Path(path),
|
path=Path(path),
|
||||||
spkr_name=spkr_name,
|
spkr_name=spkr_name,
|
||||||
spkr_id=spkr_id,
|
spkr_id=spkr_id,
|
||||||
#text=text,
|
|
||||||
#proms=proms,
|
latents=latents,
|
||||||
#resps=resps,
|
text=text,
|
||||||
|
mel=mel,
|
||||||
|
wav_length=wav_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
def head_(self, n):
|
def head_(self, n):
|
||||||
|
@@ -603,12 +608,14 @@ def create_train_val_dataloader():
 	return train_dl, subtrain_dl, val_dl

 def unpack_audio( npz ):
-	mel = torch.from_numpy(npz["codes"].astype(int))[0].t().to(dtype=torch.int16)
+	mel = npz["codes"].to(dtype=torch.int16, device="cpu")
+	conds = npz["conds"][0].to(dtype=torch.int16, device="cpu")
+	latent = npz["latent"][0].to(dtype=torch.int16, device="cpu")

 	metadata = {}

-	if "text" in npz["metadata"]:
-		metadata["text"] = npz["metadata"]["text"]
+	if "text" in npz:
+		metadata["text"] = npz["text"]

 	if "phonemes" in npz["metadata"]:
 		metadata["phonemes"] = npz["metadata"]["phonemes"]
@@ -616,10 +623,15 @@ def unpack_audio( npz ):
 	if "language" in npz["metadata"]:
 		metadata["language"] = npz["metadata"]["language"]

-	if "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
+	if "original_length" in npz["metadata"]:
+		metadata["wav_length"] = npz["metadata"]["original_length"]
+
+	if "duration" in npz["metadata"]:
+		metadata["duration"] = npz["metadata"]["duration"]
+	elif "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
 		metadata["duration"] = npz["metadata"]["original_length"] / npz["metadata"]["sample_rate"]

-	return mel, metadata
+	return mel, conds, latent, metadata

 # parse dataset into better to sample metadata
 def create_dataset_metadata( skip_existing=True ):
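unpack_audio now returns four values instead of two. A short sketch of how a caller might consume it, assuming the on-disk dict follows the layout produced by emb.mel.encode in this commit; the file path is illustrative:

	import numpy as np

	npz = np.load("./training/speaker/utterance.mel", allow_pickle=True)[()]  # hypothetical path
	mel, conds, latent, metadata = unpack_audio( npz )
	# mel: dvae codes; conds/latent: the first (autoregressive) conditioning entries;
	# metadata: text / wav_length / duration when present
	print(mel.shape, metadata.get("duration"))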
@@ -672,7 +684,7 @@ def create_dataset_metadata( skip_existing=True ):
 		if audios:
 			# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
 			npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
-			mel, utterance_metadata = unpack_audio( npz )
+			mel, conds, latents, utterance_metadata = unpack_audio( npz )
 		# text
 		if texts and text_exists and not utterance_metadata:
 			utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
@@ -755,11 +767,17 @@ def create_dataset_hdf5( skip_existing=True ):
 			# audio
 			if audios:
 				npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
-				mel, utterance_metadata = unpack_audio( npz )
+				mel, conds, latents, utterance_metadata = unpack_audio( npz )

 				if "audio" not in group:
 					group.create_dataset('audio', data=mel.numpy().astype(np.int16), compression='lzf')

+				if "conds" not in group:
+					group.create_dataset('conds', data=conds.numpy().astype(np.int16), compression='lzf')
+
+				if "latents" not in group:
+					group.create_dataset('latents', data=latents.numpy().astype(np.int16), compression='lzf')
+
 			# text
 			if texts:
 				if not utterance_metadata and text_exists:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tqdm.write(f'Error while processing {id}: {e}')
|
tqdm.write(f'Error while processing {id}: {e}')
|
||||||
|
raise e
|
||||||
|
|
||||||
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
with open(str(metadata_path), "w", encoding="utf-8") as f:
|
||||||
f.write( json.dumps( metadata ) )
|
f.write( json.dumps( metadata ) )
|
||||||
|
@@ -837,28 +856,10 @@ if __name__ == "__main__":

 	samples = {
 		"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
-		"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
-		"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
+		#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
+		#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
 	}

-	Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
-
-	for k, v in samples.items():
-		for i in range(len(v)):
-			for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
-				"""
-				try:
-					decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
-				except Exception as e:
-					print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
-				try:
-					decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
-				except Exception as e:
-					print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
-				"""
-				v[i]['proms'][j] = v[i]['proms'][j].shape
-				v[i]['resps'][j] = v[i]['resps'][j].shape
-
 	for k, v in samples.items():
 		for i in range(len(v)):
 			print(f'{k}[{i}]:', v[i])
@@ -15,6 +15,19 @@ from tqdm import tqdm

 from ..models import load_model, unload_model

+import torch.nn.functional as F
+
+def pad_or_truncate(t, length):
+	"""
+	Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
+	"""
+	if t.shape[-1] == length:
+		return t
+	elif t.shape[-1] < length:
+		return F.pad(t, (0, length-t.shape[-1]))
+	else:
+		return t[..., :length]
+
 # decodes mel spectrogram into a wav
 @torch.inference_mode()
 def decode(codes: Tensor, device="cuda"):
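A quick usage sketch of the new pad_or_truncate helper on a dummy waveform tensor (shapes chosen arbitrarily; assumes the function above is in scope):

	import torch

	wav = torch.randn(1, 90_000)
	assert pad_or_truncate(wav, 102_400).shape[-1] == 102_400  # right-padded with zeros
	assert pad_or_truncate(wav, 50_000).shape[-1] == 50_000    # clipped to the first 50k samples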
@@ -34,30 +47,30 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 def _replace_file_extension(path, suffix):
 	return (path.parent / path.name.split(".")[0]).with_suffix(suffix)

-def format_autoregressive_conditioning( wav, cond_length=132300, device ):
+def format_autoregressive_conditioning( wav, cond_length=132300, device="cuda" ):
 	"""
 	Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
 	"""
 	model = load_model("tms", device=device)

-	gap = wav.shape[-1] - cond_length
-	if gap < 0:
-		wav = F.pad(wav, pad=(0, abs(gap)))
-	elif gap > 0:
-		rand_start = random.randint(0, gap)
-		wav = wav[:, rand_start:rand_start + cond_length]
+	if cond_length > 0:
+		gap = wav.shape[-1] - cond_length
+		if gap < 0:
+			wav = F.pad(wav, pad=(0, abs(gap)))
+		elif gap > 0:
+			rand_start = random.randint(0, gap)
+			wav = wav[:, rand_start:rand_start + cond_length]

 	mel_clip = model(wav.unsqueeze(0)).squeeze(0) # ???
 	return mel_clip.unsqueeze(0).to(device) # ???

 def format_diffusion_conditioning( sample, device, do_normalization=False ):
-	model = load_model("stft", device=device)
+	model = load_model("stft", device=device, sr=24_000)

 	sample = torchaudio.functional.resample(sample, 22050, 24000)
 	sample = pad_or_truncate(sample, 102400)
 	sample = sample.to(device)
-	mel = model.mel_spectrogram(wav)
+	mel = model.mel_spectrogram(sample)
 	"""
 	if do_normalization:
 		mel = normalize_tacotron_mel(mel)
@@ -70,6 +83,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
 	dvae = load_model("dvae", device=device)
 	unified_voice = load_model("unified_voice", device=device)
 	diffusion = load_model("diffusion", device=device)
+	mel_inputs = format_autoregressive_conditioning( wav, 0, device )

 	wav_length = wav.shape[-1]
 	duration = wav_length / sr
@@ -78,16 +92,19 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
 	diffusion_conds = torch.stack([ format_diffusion_conditioning(wav.to(device), device=device) ], dim=1)

 	codes = dvae.get_codebook_indices( mel_inputs )
-	auto_latent = unified_voice.get_conditioning(autoregressive_conds)
+
+	autoregressive_latent = unified_voice.get_conditioning(autoregressive_conds)
 	diffusion_latent = diffusion.get_conditioning(diffusion_conds)

 	return {
 		"codes": codes,
 		"conds": (autoregressive_conds, diffusion_conds),
 		"latent": (autoregressive_latent, diffusion_latent),
-		"original_length": wav_length,
-		"sample_rate": sr,
-		"duration": duration
+		"metadata": {
+			"original_length": wav_length,
+			"sample_rate": sr,
+			"duration": duration
+		}
 	}

 def encode_from_files(paths, device="cuda"):
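With the length fields now nested under "metadata", a short sketch of encoding one clip end-to-end; the file path is illustrative, and as a worked example a 3-second clip at 24 kHz gives original_length = 72000 samples and duration = 72000 / 24000 = 3.0 s:

	import torchaudio
	from tortoise_tts.emb.mel import encode

	wav, sr = torchaudio.load("./voices/speaker/example.flac")  # hypothetical path
	mel = encode(wav, sr=sr, device="cuda")
	print(mel["codes"].shape)
	print(mel["metadata"])  # {"original_length": ..., "sample_rate": sr, "duration": ...}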
@@ -3,7 +3,7 @@

 from functools import cache

-from ..arch_utils import TorchMelSpectrogram, TacotronSTFT
+from .arch_utils import TorchMelSpectrogram, TacotronSTFT

 from .unified_voice import UnifiedVoice
 from .diffusion import DiffusionTTS
@@ -11,26 +11,45 @@ from .vocoder import UnivNetGenerator
 from .clvp import CLVP
 from .dvae import DiscreteVAE

+import os
+import torch
+
+DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/')
+
 # semi-necessary as a way to provide a mechanism for other portions of the program to access models
 @cache
 def load_model(name, device="cuda", **kwargs):
+	load_path = None
 	if "autoregressive" in name or "unified_voice" in name:
 		model = UnifiedVoice(**kwargs)
+		load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
 	elif "diffusion" in name:
 		model = DiffusionTTS(**kwargs)
+		load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
 	elif "clvp" in name:
 		model = CLVP(**kwargs)
+		load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
 	elif "vocoder" in name:
 		model = UnivNetGenerator(**kwargs)
+		load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
 	elif "dvae" in name:
+		load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
 		model = DiscreteVAE(**kwargs)
-	# to-do: figure out of the below two give the exact same output, since the AR uses #1, the Diffusion uses #2
+	# to-do: figure out of the below two give the exact same output
 	elif "stft" in name:
-		model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
+		sr = kwargs.pop("sr")
+		if sr == 24_000:
+			model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
+		else:
+			model = TacotronSTFT(**kwargs)
 	elif "tms" in name:
 		model = TorchMelSpectrogram(**kwargs)

 	model = model.to(device=device)

+	if load_path is not None:
+		model.load_state_dict(torch.load(load_path, map_location=device), strict=False)
+
 	return model

 def unload_model():
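Because load_model is wrapped in functools.cache, each distinct (name, device, kwargs) call builds the module and loads its weights from DEFAULT_MODEL_PATH only once. A small usage sketch, assuming the corresponding .pth files exist under ./data/:

	dvae = load_model("dvae", device="cuda")              # loads dvae.pth relative to the models package
	dvae_again = load_model("dvae", device="cuda")        # cached: same object, no second load
	stft = load_model("stft", device="cuda", sr=24_000)   # selects the 24 kHz TacotronSTFT configuration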
@@ -289,7 +289,7 @@ class AudioMiniEncoder(nn.Module):
 		return h[:, :, 0]


-DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
+DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/mel_norms.pth')


 class TorchMelSpectrogram(nn.Module):
|
||||||
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
TACOTRON_MEL_MAX = 2.3143386840820312
|
||||||
|
TACOTRON_MEL_MIN = -11.512925148010254
|
||||||
|
|
||||||
|
|
||||||
|
def denormalize_tacotron_mel(norm_mel):
|
||||||
|
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_tacotron_mel(mel):
|
||||||
|
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
||||||
|
"""
|
||||||
|
PARAMS
|
||||||
|
------
|
||||||
|
C: compression factor
|
||||||
|
"""
|
||||||
|
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_decompression(x, C=1):
|
||||||
|
"""
|
||||||
|
PARAMS
|
||||||
|
------
|
||||||
|
C: compression factor used to compress
|
||||||
|
"""
|
||||||
|
return torch.exp(x) / C
|
||||||
|
|
||||||
class STFT(torch.nn.Module):
|
class STFT(torch.nn.Module):
|
||||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||||
|
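normalize_tacotron_mel maps log-mel values from [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] onto [-1, 1] and denormalize_tacotron_mel inverts it, so the two should round-trip. A quick sanity sketch (dummy tensor, arbitrary shape):

	import torch

	mel = torch.empty(1, 100, 50).uniform_(TACOTRON_MEL_MIN, TACOTRON_MEL_MAX)
	norm = normalize_tacotron_mel(mel)
	assert norm.min() >= -1 and norm.max() <= 1
	assert torch.allclose(denormalize_tacotron_mel(norm), mel, atol=1e-5)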
@@ -566,9 +594,16 @@ class STFT(torch.nn.Module):
 		return reconstruction

 class TacotronSTFT(torch.nn.Module):
-	def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
-				 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
-				 mel_fmax=8000.0):
+	def __init__(
+		self,
+		filter_length=1024,
+		hop_length=256,
+		win_length=1024,
+		n_mel_channels=80,
+		sampling_rate=22050,
+		mel_fmin=0.0,
+		mel_fmax=8000.0
+	):
 		super().__init__()
 		self.n_mel_channels = n_mel_channels
 		self.sampling_rate = sampling_rate
@@ -119,7 +119,7 @@ class DiscreteVAE(nn.Module):
 		positional_dims = 1, # 2
 		num_tokens = 8192, # 512
 		codebook_dim = 512,
-		num_layers = 3,
+		num_layers = 2, # 3
 		num_resnet_blocks = 3, # 0
 		hidden_dim = 512, # 64
 		channels = 80, # 3
@@ -2,10 +2,9 @@

 from .config import cfg
 from .data import create_train_val_dataloader
-from .emb import qnt
+from .emb import mel

 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
-from .data import fold_inputs, unfold_outputs
 from .utils.distributed import is_global_leader

 import auraloss
@@ -34,7 +33,7 @@ def train_feeder(engine, batch):
 	text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
 	text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True)
 	mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True)
-	wav_lengths = pad_sequence([ length for length in batch["wav_lengths"] ], batch_first = True)
+	wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True)

 	engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)
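The feeder collates the per-sample fields emitted by Dataset.__getitem__ (text, mel, wav_length, latents) into batch tensors with torch.nn.utils.rnn.pad_sequence. A standalone sketch of that padding behaviour on dummy data (lengths chosen arbitrarily; this illustrates the general collation pattern, not this repo's exact feeder):

	import torch
	from torch.nn.utils.rnn import pad_sequence

	texts = [torch.randint(0, 255, (n,)) for n in (12, 7, 19)]  # variable-length token sequences
	text_inputs = pad_sequence(texts, batch_first=True)         # shape (3, 19), zero-padded
	text_lengths = torch.tensor([t.shape[0] for t in texts])    # lengths kept alongside the padded batch
	print(text_inputs.shape, text_lengths)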
@@ -1 +1 @@
-__version__ = "0.0.1-dev20240602082927"
+__version__ = "0.0.1-dev20240617224834"