tortoise-tts/tortoise_tts/emb/mel.py

from ..config import cfg
import argparse
import random
import torch
import torchaudio
from functools import cache
from pathlib import Path
from typing import Union
from einops import rearrange
from torch import Tensor
from tqdm import tqdm
from ..models import load_model, unload_model
import torch.nn.functional as F

def pad_or_truncate(t, length):
    """
    Utility function for forcing <t> to have the specified sequence length, whether by clipping it or padding it with 0s.
    """
    if t.shape[-1] == length:
        return t
    elif t.shape[-1] < length:
        return F.pad(t, (0, length - t.shape[-1]))
    else:
        return t[..., :length]
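# e.g. pad_or_truncate(torch.zeros(1, 100), 132300) zero-pads the last dimension up to
# 132300 samples, while a longer signal would be clipped down to that length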

# decodes a mel spectrogram into a wav
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    model = load_model("vocoder", device)
    return model.inference(codes)

# alias of decode()
def decode_to_wave(resps: Tensor, device="cuda"):
    return decode(resps, device=device)

def decode_to_file(resps: Tensor, path: Path, device="cuda"):
    wavs, sr = decode(resps, device=device)
    torchaudio.save(str(path), wavs.cpu(), sr)
    return wavs, sr
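# e.g. decode_to_file(mel, Path("out.wav")) decodes the given spectrogram and writes the result to disk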

def _replace_file_extension(path, suffix):
    return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
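# e.g. _replace_file_extension(Path("voice.mp3"), ".pt") -> Path("voice.pt")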

def format_autoregressive_conditioning( wav, cond_length=132300, device="cuda" ):
    """
    Converts the given conditioning signal to a MEL spectrogram and clips it as expected by the models.
    """
    model = load_model("tms", device=device)

    if cond_length > 0:
        gap = wav.shape[-1] - cond_length
        if gap < 0:
            wav = F.pad(wav, pad=(0, abs(gap)))
        elif gap > 0:
            rand_start = random.randint(0, gap)
            wav = wav[:, rand_start:rand_start + cond_length]

    mel_clip = model(wav.unsqueeze(0)).squeeze(0) # add a batch dimension for the mel model, then drop it
    return mel_clip.unsqueeze(0).to(device) # re-add a batch dimension for downstream consumers
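# note: the default cond_length of 132300 samples corresponds to 6 seconds at the 22.05 kHz
# rate this module resamples from (22050 * 6 = 132300)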

def format_diffusion_conditioning( sample, device, do_normalization=False ):
    model = load_model("stft", device=device, sr=24_000)

    sample = torchaudio.functional.resample(sample, 22050, 24000)
    sample = pad_or_truncate(sample, 102400)
    sample = sample.to(device)
    mel = model.mel_spectrogram(sample)
    """
    if do_normalization:
        mel = normalize_tacotron_mel(mel)
    """
    return mel
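# note: 102400 samples at the 24 kHz diffusion rate is roughly 4.27 seconds of audio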

# encode a wav to conditioning latents + mel codes
@torch.inference_mode()
def encode(wav: Tensor, sr: int = cfg.sample_rate, device="cuda"):
    dvae = load_model("dvae", device=device)
    unified_voice = load_model("unified_voice", device=device)
    diffusion = load_model("diffusion", device=device)

    mel_inputs = format_autoregressive_conditioning( wav, 0, device )

    wav_length = wav.shape[-1]
    duration = wav_length / sr

    autoregressive_conds = torch.stack([ format_autoregressive_conditioning(wav.to(device), device=device) ], dim=1)
    diffusion_conds = torch.stack([ format_diffusion_conditioning(wav.to(device), device=device) ], dim=1)

    codes = dvae.get_codebook_indices( mel_inputs )
    autoregressive_latent = unified_voice.get_conditioning(autoregressive_conds)
    diffusion_latent = diffusion.get_conditioning(diffusion_conds)

    return {
        "codes": codes,
        "conds": (autoregressive_conds, diffusion_conds),
        "latent": (autoregressive_latent, diffusion_latent),
        "metadata": {
            "original_length": wav_length,
            "sample_rate": sr,
            "duration": duration
        }
    }
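# illustrative usage (filename is hypothetical):
#   wav, sr = torchaudio.load("reference.wav")
#   feats = encode(wav, sr)
#   codes, (ar_latent, diff_latent) = feats["codes"], feats["latent"]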

def encode_from_files(paths, device="cuda"):
    tuples = [ torchaudio.load(str(path)) for path in paths ]

    wavs = []
    main_sr = tuples[0][1]
    for wav, sr in tuples:
        assert sr == main_sr, "Mismatching sample rates"
        if wav.shape[0] == 2:
            wav = wav[:1]
        wavs.append(wav)

    wav = torch.cat(wavs, dim=-1)
    return encode(wav, sr, device)

def encode_from_file(path, device="cuda"):
    if isinstance( path, list ):
        return encode_from_files( path, device )
    else:
        path = str(path)
        wav, sr = torchaudio.load(path)

    if wav.shape[0] == 2:
        wav = wav[:1]

    qnt = encode(wav, sr, device)
    return qnt
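# encode_from_file() accepts a single path or a list of paths; a list is concatenated
# into one waveform via encode_from_files() before encoding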
"""
Helper Functions
"""

# trims from the start, up to `target`
def trim( qnt, target ):
    length = qnt.shape[0]
    if target > 0:
        start = 0
        end = start + target
        if end >= length:
            start = length - target
            end = length
    # negative length specified, trim from end
    else:
        start = length + target
        end = length
        if start < 0:
            start = 0

    return qnt[start:end]
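# e.g. with a 100-frame qnt, trim(qnt, 25) keeps frames [0:25] while trim(qnt, -25)
# keeps frames [75:100]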

# trims a random piece of audio, up to `target`
# to-do: try and align to EnCodec window
def trim_random( qnt, target ):
    length = qnt.shape[0]
    start = int(length * random.random())
    end = start + target
    if end >= length:
        start = length - target
        end = length

    return qnt[start:end]

# repeats the audio to fit the target size
def repeat_extend_audio( qnt, target ):
    pieces = []
    length = 0
    while length < target:
        pieces.append(qnt)
        length += qnt.shape[0]
    return trim(torch.cat(pieces), target)
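# e.g. a 40-frame qnt repeated to target=100 is tiled to 120 frames, then trimmed
# back down to the first 100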

# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu", scale=[] ):
    qnts = [*args]
    decoded = [ decode(qnt, device=device)[0] for qnt in qnts ]

    # apply a per-source volume scale if one was provided
    if len(scale) == len(decoded):
        for i in range(len(scale)):
            decoded[i] = decoded[i] * scale[i]

    combined = sum(decoded) / len(decoded)
    return encode(combined, cfg.sample_rate, device="cpu")["codes"].t()
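# illustrative usage for multiple clips of the same speaker (paths are hypothetical):
#   feats = encode_from_files([ "clip1.wav", "clip2.wav" ])  # clips are concatenated before encoding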