encoding mel tokens + dataset preparation

This commit is contained in:
mrq 2024-06-18 10:30:54 -05:00
parent 37ec9f1b79
commit d7b63d2f70
14 changed files with 182 additions and 296 deletions

BIN
data/mel_norms.pth Normal file

Binary file not shown.

1
data/tokenizer.json Normal file
View File

@ -0,0 +1 @@
{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}}

View File

@ -1,96 +0,0 @@
import os
import json
import torch
from tqdm.auto import tqdm
from pathlib import Path
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
device = "cuda"
target = "in"
audio_map = {}
text_map = {}
data = {}
for season in os.listdir(f"./{target}/"):
if not os.path.isdir(f"./{target}/{season}/"):
continue
for episode in os.listdir(f"./{target}/{season}/"):
if not os.path.isdir(f"./{target}/{season}/{episode}/"):
continue
for filename in os.listdir(f"./{target}/{season}/{episode}/"):
path = f'./{target}/{season}/{episode}/{filename}'
attrs = filename.split("_")
timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
key = f'{episode}_{timestamp}'
if filename[-5:] == ".flac":
name = attrs[3]
emotion = attrs[4]
quality = attrs[5]
audio_map[key] = {
"path": path,
'episode': episode,
"name": name,
"emotion": emotion,
"quality": quality,
"timestamp": timestamp,
}
elif filename[-4:] == ".txt":
text_map[key] = open(path, encoding="utf-8").read()
txts = {}
wavs = []
for key, entry in audio_map.items():
path = entry['path']
name = entry['name']
emotion = entry['emotion']
quality = entry['quality']
episode = entry['episode']
path = entry['path']
timestamp = entry['timestamp']
transcription = text_map[key]
if name not in data:
data[name] = {}
os.makedirs(f'./training/{name}/', exist_ok=True)
os.makedirs(f'./voices/{name}/', exist_ok=True)
key = f'{episode}_{timestamp}.flac'
os.rename(path, f'./voices/{name}/{key}')
data[name][key] = {
"segments": [],
"language": "en",
"text": transcription,
"misc": {
"emotion": emotion,
"quality": quality,
"timestamp": timestamp,
"episode": episode,
}
}
path = f'./voices/{name}/{key}'
txts[path] = transcription
wavs.append(Path(path))
for name in data.keys():
open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
for key, text in tqdm(txts.items(), desc="Phonemizing..."):
path = Path(key)
phones = valle_phonemize(text)
open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
for path in tqdm(wavs, desc="Quantizing..."):
qnt = valle_quantize(path, device=device)
torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))

View File

@ -5,7 +5,7 @@ import torchaudio
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
from vall_e.config import cfg
from tortoise_tts.config import cfg
# things that could be args
cfg.sample_rate = 24_000
@ -16,15 +16,14 @@ cfg.inference.dtype = torch.bfloat16
cfg.inference.amp = True
"""
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode as valle_quantize, _replace_file_extension
from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension
input_audio = "voices"
input_metadata = "metadata"
output_dataset = f"training-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz-{cfg.inference.audio_backend}"
device = "cuda"
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
audio_extension = ".mel"
slice = "auto"
missing = {
@ -57,30 +56,8 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
continue
waveform, sample_rate = torchaudio.load(inpath)
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
},
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,
},
})
mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel )
continue
@ -177,45 +154,16 @@ for dataset_name in sorted(os.listdir(f'./{input_audio}/')):
))
if len(wavs) > 0:
for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"):
for job in tqdm(wavs, desc=f"Encoding: {speaker_id}"):
try:
outpath, text, language, waveform, sample_rate = job
phones = valle_phonemize(text)
qnt = valle_quantize(waveform, sr=sample_rate, device=device)
if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
"text": text.strip(),
"phonemes": "".join(phones),
"language": language,
},
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": waveform.shape[-1],
"sample_rate": sample_rate,
"text": text.strip(),
"phonemes": "".join(phones),
"language": language,
},
})
mel = tortoise_mel_encode(waveform, sr=sample_rate, device=device)
mel["text"] = text
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
except Exception as e:
print(f"Failed to quantize: {outpath}:", e)
print(f"Failed to encode: {outpath}:", e)
continue
open("./missing.json", 'w', encoding='utf-8').write(json.dumps(missing))

View File

@ -18,9 +18,9 @@ cfg.inference.amp = True
"""
from vall_e.emb.g2p import encode as valle_phonemize
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
from tortoise_tts.emb.mel import encode as tortoise_mel_encode, _replace_file_extension
audio_extension = ".dac" if cfg.inference.audio_backend == "dac" else ".enc"
audio_extension = ".mel"
input_dataset = "LibriTTS_R"
output_dataset = f"LibriTTS-Train-{'2' if cfg.sample_rate == 24_000 else '4'}4KHz"
@ -56,52 +56,10 @@ for dataset_name in os.listdir(f'./{input_dataset}/'):
for paths in tqdm(txts, desc="Processing..."):
inpath, outpath = paths
try:
if _replace_file_extension(outpath, ".dac").exists() and _replace_file_extension(outpath, ".json").exists():
data = json.loads(open(_replace_file_extension(outpath, ".json"), 'r', encoding='utf-8').read())
qnt = np.load(_replace_file_extension(outpath, audio_extension), allow_pickle=True)
if not isinstance(data["phonemes"], str):
data["phonemes"] = "".join(data["phonemes"])
for k, v in data.items():
qnt[()]['metadata'][k] = v
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), qnt)
else:
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
phones = valle_phonemize(text)
qnt = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
if cfg.inference.audio_backend == "dac":
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.codes.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.original_length,
"sample_rate": qnt.sample_rate,
"input_db": qnt.input_db.cpu().numpy().astype(np.float32),
"chunk_length": qnt.chunk_length,
"channels": qnt.channels,
"padding": qnt.padding,
"dac_version": "1.0.0",
"text": text.strip(),
"phonemes": "".join(phones),
"language": "en",
},
})
else:
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
"codes": qnt.cpu().numpy().astype(np.uint16),
"metadata": {
"original_length": qnt.shape[-1] / 75.0,
"sample_rate": cfg.sample_rate,
"text": text.strip(),
"phonemes": "".join(phones),
"language": "en",
},
})
text = open(_replace_file_extension(inpath, ".original.txt"), "r", encoding="utf-8").read()
mel = valle_quantize(_replace_file_extension(inpath, ".wav"), device=device)
mel["text"] = text
np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), mel)
except Exception as e:
tqdm.write(f"Failed to process: {paths}: {e}")

View File

@ -54,6 +54,9 @@ setup(
"tokenizers",
"transformers",
#
"rotary_embedding_torch",
# training bloat
"auraloss[all]", # [all] is needed for MelSTFTLoss
"h5py",

View File

@ -502,7 +502,7 @@ class Optimizations:
optimizers: bool = True # inject/replace optimizers (BnB, DAdaptation)
bitsandbytes: bool = False # use bitsandbytes
dadaptation: bool = True # use dadaptation optimizer
dadaptation: bool = False # use dadaptation optimizer
bitnet: bool = False # use bitnet
fp8: bool = False # use fp8
@ -525,6 +525,7 @@ class Config(BaseConfig):
tokenizer: str = "./tokenizer.json"
sample_rate: int = 24_000
audio_backend: str = "mel"
@property
def model(self):

View File

@ -169,8 +169,10 @@ def _get_paths_of_extensions( path, extensions=_get_mel_extension(), validate=Fa
def _load_mels(path, return_metadata=False) -> Tensor:
mel = np.load(_get_mel_path(path), allow_pickle=True)[()]
if return_metadata:
return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16), mel["metadata"]
return torch.from_numpy(mel["codes"].astype(int))[0][:].t().to(torch.int16)
mel["metadata"]["text"] = mel["text"]
return mel["codes"].to(torch.int16), mel["metadata"]
return mel["codes"].to(torch.int16)
# prune consecutive spaces
def _cleanup_phones( phones, targets=[" "]):
@ -453,37 +455,19 @@ class Dataset(_Dataset):
)
"""
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
for _ in range(cfg.dataset.max_prompts):
path = random.choice(choices)
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if "audio" not in cfg.hdf5[key]:
_logger.warning(f'MISSING AUDIO: {key}')
return
if "audio" not in cfg.hdf5[key]:
_logger.warning(f'MISSING AUDIO: {key}')
continue
mel = torch.from_numpy(cfg.hdf5[key]["audio"][:]).to(torch.int16)
else:
mel = _load_mels(path, return_metadata=False)
if 0 < trim_length and trim_length < mel.shape[0]:
mel = trim( mel, trim_length )
prom_list.append(mel)
prom_length += mel.shape[0]
if prom_length >= trim_length or random.random() > cfg.dataset.random_utterance:
break
# might be better to decode => concat waveforms with silence in between => reencode
# as you technically can't just append encodec sequences together like this without issues
prom = torch.cat(prom_list)
if 0 < trim_length and trim_length < prom.shape[0]:
prom = trim( prom, trim_length )
# audio / cond / latents
# parameter names and documentation are weird
prom = torch.from_numpy(cfg.hdf5[key]["cond"]).to(torch.int16)
else:
prom = _load_mels(path, return_metadata=False)
return prom
@ -507,15 +491,36 @@ class Dataset(_Dataset):
spkr_group = self.get_speaker_group(path)
#spkr_group_id = self.spkr_group_symmap[spkr_group]
if cfg.dataset.use_hdf5:
key = _get_hdf5_path(path)
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
text = cfg.hdf5[key]["text"][:]
mel = cfg.hdf5[key]["audio"][:]
latents = cfg.hdf5[key]["latents"][:]
text = torch.from_numpy(text).to(self.text_dtype)
mel = torch.from_numpy(mel).to(torch.int16)
latents = torch.from_numpy(latents)
wav_length = cfg.hdf5[key].attrs["wav_length"]
else:
mel, metadata = _load_mels(path, return_metadata=True)
text = torch.tensor(metadata["text"]).to(self.text_dtype)
latents = torch.from_numpy(metadata["latent"][0])
wav_length = metadata["wav_length"]
return dict(
index=index,
path=Path(path),
spkr_name=spkr_name,
spkr_id=spkr_id,
#text=text,
#proms=proms,
#resps=resps,
latents=latents,
text=text,
mel=mel,
wav_length=wav_length,
)
def head_(self, n):
@ -603,12 +608,14 @@ def create_train_val_dataloader():
return train_dl, subtrain_dl, val_dl
def unpack_audio( npz ):
mel = torch.from_numpy(npz["codes"].astype(int))[0].t().to(dtype=torch.int16)
mel = npz["codes"].to(dtype=torch.int16, device="cpu")
conds = npz["conds"][0].to(dtype=torch.int16, device="cpu")
latent = npz["latent"][0].to(dtype=torch.int16, device="cpu")
metadata = {}
if "text" in npz["metadata"]:
metadata["text"] = npz["metadata"]["text"]
if "text" in npz:
metadata["text"] = npz["text"]
if "phonemes" in npz["metadata"]:
metadata["phonemes"] = npz["metadata"]["phonemes"]
@ -616,10 +623,15 @@ def unpack_audio( npz ):
if "language" in npz["metadata"]:
metadata["language"] = npz["metadata"]["language"]
if "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
if "original_length" in npz["metadata"]:
metadata["wav_length"] = npz["metadata"]["original_length"]
if "duration" in npz["metadata"]:
metadata["duration"] = npz["metadata"]["duration"]
elif "original_length" in npz["metadata"] and "sample_rate" in npz["metadata"]:
metadata["duration"] = npz["metadata"]["original_length"] / npz["metadata"]["sample_rate"]
return mel, metadata
return mel, conds, latent, metadata
# parse dataset into better to sample metadata
def create_dataset_metadata( skip_existing=True ):
@ -672,7 +684,7 @@ def create_dataset_metadata( skip_existing=True ):
if audios:
# ideally we'll encode Encodec-based audio in a similar manner because np has smaller files than pt
npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
mel, utterance_metadata = unpack_audio( npz )
mel, conds, latents, utterance_metadata = unpack_audio( npz )
# text
if texts and text_exists and not utterance_metadata:
utterance_metadata = json.loads(open(f'{root}/{name}/{id}{_get_phone_extension()}', "r", encoding="utf-8").read())
@ -755,10 +767,16 @@ def create_dataset_hdf5( skip_existing=True ):
# audio
if audios:
npz = np.load(f'{root}/{name}/{id}{_get_mel_extension()}', allow_pickle=True)[()]
mel, utterance_metadata = unpack_audio( npz )
mel, conds, latents, utterance_metadata = unpack_audio( npz )
if "audio" not in group:
group.create_dataset('audio', data=mel.numpy().astype(np.int16), compression='lzf')
if "conds" not in group:
group.create_dataset('conds', data=conds.numpy().astype(np.int16), compression='lzf')
if "latents" not in group:
group.create_dataset('latents', data=latents.numpy().astype(np.int16), compression='lzf')
# text
if texts:
@ -778,6 +796,7 @@ def create_dataset_hdf5( skip_existing=True ):
except Exception as e:
tqdm.write(f'Error while processing {id}: {e}')
raise e
with open(str(metadata_path), "w", encoding="utf-8") as f:
f.write( json.dumps( metadata ) )
@ -837,27 +856,9 @@ if __name__ == "__main__":
samples = {
"training": [ next(iter(train_dl)), next(iter(train_dl)) ],
"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
#"evaluation": [ next(iter(subtrain_dl)), next(iter(subtrain_dl)) ],
#"validation": [ next(iter(val_dl)), next(iter(val_dl)) ],
}
Path("./data/sample-test/").mkdir(parents=True, exist_ok=True)
for k, v in samples.items():
for i in range(len(v)):
for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."):
"""
try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
"""
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items():
for i in range(len(v)):

View File

@ -15,6 +15,19 @@ from tqdm import tqdm
from ..models import load_model, unload_model
import torch.nn.functional as F
def pad_or_truncate(t, length):
"""
Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
"""
if t.shape[-1] == length:
return t
elif t.shape[-1] < length:
return F.pad(t, (0, length-t.shape[-1]))
else:
return t[..., :length]
# decodes mel spectrogram into a wav
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
@ -34,30 +47,30 @@ def decode_to_file(resps: Tensor, path: Path, device="cuda"):
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
def format_autoregressive_conditioning( wav, cond_length=132300, device ):
def format_autoregressive_conditioning( wav, cond_length=132300, device="cuda" ):
"""
Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
"""
model = load_model("tms", device=device)
gap = wav.shape[-1] - cond_length
if gap < 0:
wav = F.pad(wav, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
wav = wav[:, rand_start:rand_start + cond_length]
if cond_length > 0:
gap = wav.shape[-1] - cond_length
if gap < 0:
wav = F.pad(wav, pad=(0, abs(gap)))
elif gap > 0:
rand_start = random.randint(0, gap)
wav = wav[:, rand_start:rand_start + cond_length]
mel_clip = model(wav.unsqueeze(0)).squeeze(0) # ???
return mel_clip.unsqueeze(0).to(device) # ???
def format_diffusion_conditioning( sample, device, do_normalization=False ):
model = load_model("stft", device=device)
model = load_model("stft", device=device, sr=24_000)
sample = torchaudio.functional.resample(sample, 22050, 24000)
sample = pad_or_truncate(sample, 102400)
sample = sample.to(device)
mel = model.mel_spectrogram(wav)
mel = model.mel_spectrogram(sample)
"""
if do_normalization:
mel = normalize_tacotron_mel(mel)
@ -70,6 +83,7 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
dvae = load_model("dvae", device=device)
unified_voice = load_model("unified_voice", device=device)
diffusion = load_model("diffusion", device=device)
mel_inputs = format_autoregressive_conditioning( wav, 0, device )
wav_length = wav.shape[-1]
duration = wav_length / sr
@ -78,16 +92,19 @@ def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
diffusion_conds = torch.stack([ format_diffusion_conditioning(wav.to(device), device=device) ], dim=1)
codes = dvae.get_codebook_indices( mel_inputs )
auto_latent = unified_voice.get_conditioning(autoregressive_conds)
autoregressive_latent = unified_voice.get_conditioning(autoregressive_conds)
diffusion_latent = diffusion.get_conditioning(diffusion_conds)
return {
"codes": codes,
"conds": (autoregressive_conds, diffusion_conds),
"latent": (autoregressive_latent, diffusion_latent),
"original_length": wav_length,
"sample_rate": sr,
"duration": duration
"metadata": {
"original_length": wav_length,
"sample_rate": sr,
"duration": duration
}
}
def encode_from_files(paths, device="cuda"):

View File

@ -3,7 +3,7 @@
from functools import cache
from ..arch_utils import TorchMelSpectrogram, TacotronSTFT
from .arch_utils import TorchMelSpectrogram, TacotronSTFT
from .unified_voice import UnifiedVoice
from .diffusion import DiffusionTTS
@ -11,26 +11,45 @@ from .vocoder import UnivNetGenerator
from .clvp import CLVP
from .dvae import DiscreteVAE
import os
import torch
DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/')
# semi-necessary as a way to provide a mechanism for other portions of the program to access models
@cache
def load_model(name, device="cuda", **kwargs):
load_path = None
if "autoregressive" in name or "unified_voice" in name:
model = UnifiedVoice(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
elif "diffusion" in name:
model = DiffusionTTS(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
elif "clvp" in name:
model = CLVP(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
elif "vocoder" in name:
model = UnivNetGenerator(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
elif "dvae" in name:
load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
model = DiscreteVAE(**kwargs)
# to-do: figure out of the below two give the exact same output, since the AR uses #1, the Diffusion uses #2
# to-do: figure out of the below two give the exact same output
elif "stft" in name:
model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
sr = kwargs.pop("sr")
if sr == 24_000:
model = TacotronSTFT(1024, 256, 1024, 100, 24000, 0, 12000, **kwargs)
else:
model = TacotronSTFT(**kwargs)
elif "tms" in name:
model = TorchMelSpectrogram(**kwargs)
model = model.to(device=device)
if load_path is not None:
model.load_state_dict(torch.load(load_path, map_location=device), strict=False)
return model
def unload_model():

View File

@ -289,7 +289,7 @@ class AudioMiniEncoder(nn.Module):
return h[:, :, 0]
DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/mel_norms.pth')
DEFAULT_MEL_NORM_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../data/mel_norms.pth')
class TorchMelSpectrogram(nn.Module):
@ -463,6 +463,34 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
return x
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254
def denormalize_tacotron_mel(norm_mel):
return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
def normalize_tacotron_mel(mel):
return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
class STFT(torch.nn.Module):
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
@ -566,9 +594,16 @@ class STFT(torch.nn.Module):
return reconstruction
class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=8000.0):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=80,
sampling_rate=22050,
mel_fmin=0.0,
mel_fmax=8000.0
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate

View File

@ -119,7 +119,7 @@ class DiscreteVAE(nn.Module):
positional_dims = 1, # 2
num_tokens = 8192, # 512
codebook_dim = 512,
num_layers = 3,
num_layers = 2, # 3
num_resnet_blocks = 3, # 0
hidden_dim = 512, # 64
channels = 80, # 3

View File

@ -2,10 +2,9 @@
from .config import cfg
from .data import create_train_val_dataloader
from .emb import qnt
from .emb import mel
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .data import fold_inputs, unfold_outputs
from .utils.distributed import is_global_leader
import auraloss
@ -34,7 +33,7 @@ def train_feeder(engine, batch):
text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True)
mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True)
wav_lengths = pad_sequence([ length for length in batch["wav_lengths"] ], batch_first = True)
wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True)
engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)

View File

@ -1 +1 @@
__version__ = "0.0.1-dev20240602082927"
__version__ = "0.0.1-dev20240617224834"