working, the vocoder was just loading wrong

mrq 2024-06-18 20:55:50 -05:00
parent b5570f1b86
commit fb313d7ef4
10 changed files with 469 additions and 330 deletions

View File

@@ -12,9 +12,9 @@ Simply run `pip install git+https://git.ecker.tech/mrq/tortoise-tts` or `pip ins
## To-Do
- [ ] Reimplement original inferencing through TorToiSe (as done with `api.py`)
- [ ] Implement training support (without DLAS)
- [ ] Feature parity with the VALL-E training setup for preparing a dataset ahead of time
- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
- [X] Implement training support (without DLAS)
- [X] Feature parity with the VALL-E training setup for preparing a dataset ahead of time
- [ ] Automagic handling of the original weights into compatible weights
- [ ] Extend the original inference routine with additional features:
- [x] non-float32 / mixed precision

View File

@@ -10,53 +10,62 @@ def main():
parser = argparse.ArgumentParser("VALL-E TTS")
parser.add_argument("text")
parser.add_argument("references", type=path_list)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--max-ar-steps", type=int, default=12 * cfg.dataset.frames_per_second)
parser.add_argument("--max-nar-levels", type=int, default=7)
parser.add_argument("--max-ar-context", type=int, default=-1)
parser.add_argument("--max-ar-steps", type=int, default=500)
parser.add_argument("--max-diffusion-steps", type=int, default=80)
parser.add_argument("--ar-temp", type=float, default=1.0)
parser.add_argument("--nar-temp", type=float, default=0.01)
parser.add_argument("--min-ar-temp", type=float, default=-1.0)
parser.add_argument("--min-nar-temp", type=float, default=-1.0)
parser.add_argument("--input-prompt-length", type=float, default=3.0)
parser.add_argument("--diffusion-temp", type=float, default=0.01)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=16)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
#parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--beam-width", type=int, default=0)
parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--device", type=str, default=None)
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default=None)
"""
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--max-nar-levels", type=int, default=7)
parser.add_argument("--max-ar-context", type=int, default=-1)
#parser.add_argument("--min-ar-temp", type=float, default=-1.0)
#parser.add_argument("--min-nar-temp", type=float, default=-1.0)
#parser.add_argument("--input-prompt-length", type=float, default=3.0)
parser.add_argument("--mirostat-tau", type=float, default=0)
parser.add_argument("--mirostat-eta", type=float, default=0)
"""
args = parser.parse_args()
tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference(
text=args.text,
references=args.references,
language=args.language,
out_path=args.out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
max_ar_context=args.max_ar_context,
ar_temp=args.ar_temp, nar_temp=args.nar_temp,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
top_p=args.top_p, top_k=args.top_k,
repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
max_ar_steps=args.max_ar_steps,
max_diffusion_steps=args.max_diffusion_steps,
ar_temp=args.ar_temp,
diffusion_temp=args.diffusion_temp,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
#repetition_penalty_decay=args.repetition_penalty_decay,
length_penalty=args.length_penalty,
beam_width=args.beam_width,
)
"""
language=args.language,
input_prompt_length=args.input_prompt_length,
max_nar_levels=args.max_nar_levels,
max_ar_context=args.max_ar_context,
min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp,
mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta
)
"""
if __name__ == "__main__":
main()
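
For reference, a minimal invocation of this slimmed-down CLI might look like the following (the module entry point and file paths are assumptions for illustration; only the arguments left active above are passed):

```
python3 -m tortoise_tts "Hello world." "./voices/reference.wav" \
    --yaml ./config.yaml \
    --max-ar-steps 500 --max-diffusion-steps 80 \
    --ar-temp 1.0 --diffusion-temp 0.01 \
    --out-path ./out.wav
```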

View File

@@ -17,6 +17,7 @@ from functools import cached_property
from pathlib import Path
from .utils.distributed import world_size
from .tokenizer import VoiceBpeTokenizer
# Yuck
from transformers import PreTrainedTokenizerFast
@@ -496,177 +497,6 @@ class Inference:
return torch.float8_e4m3fn
return torch.float32
import inflect
import re
# Regular expression matching whitespace:
from unidecode import unidecode
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
text = text.replace('"', '')
return text
class VoiceBpeTokenizer:
def __init__(self, tokenizer_file=None):
if tokenizer_file is not None:
self.tokenizer = Tokenizer.from_file(tokenizer_file)
def preprocess_text(self, txt):
txt = english_cleaners(txt)
return txt
def encode(self, txt):
txt = self.preprocess_text(txt)
txt = txt.replace(' ', '[SPACE]')
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
txt = txt.replace('[SPACE]', ' ')
txt = txt.replace('[STOP]', '')
txt = txt.replace('[UNK]', '')
return txt
def get_vocab(self):
return self.tokenizer.get_vocab()
# should be renamed to optimizations
@dataclass()
class Optimizations:

View File

@@ -499,19 +499,19 @@ class Dataset(_Dataset):
text = cfg.hdf5[key]["text"][:]
mel = cfg.hdf5[key]["audio"][:]
conds = (cfg.hdf5[key]["conds_0"][:], cfg.hdf5[key]["conds_1"][:])
#conds = (cfg.hdf5[key]["conds_0"][:], cfg.hdf5[key]["conds_1"][:])
latents = (cfg.hdf5[key]["latents_0"][:], cfg.hdf5[key]["latents_1"][:])
text = torch.from_numpy(text).to(self.text_dtype)
mel = torch.from_numpy(mel).to(torch.int16)
conds = (torch.from_numpy(conds[0]), torch.from_numpy(conds[1]))
#conds = (torch.from_numpy(conds[0]), torch.from_numpy(conds[1]))
latents = (torch.from_numpy(latents[0]), torch.from_numpy(latents[1]))
wav_length = cfg.hdf5[key].attrs["wav_length"]
else:
mel, metadata = _load_mels(path, return_metadata=True)
text = torch.tensor(metadata["text"]).to(self.text_dtype)
conds = (torch.from_numpy(metadata["conds"][0]), torch.from_numpy(metadata["conds"][1]))
#conds = (torch.from_numpy(metadata["conds"][0]), torch.from_numpy(metadata["conds"][1]))
latents = (torch.from_numpy(metadata["latent"][0]), torch.from_numpy(metadata["latent"][1]))
wav_length = metadata["wav_length"]
@@ -524,8 +524,8 @@ class Dataset(_Dataset):
latents_0=latents[0][0],
latents_1=latents[1][0],
conds_0=conds[0][0, 0],
conds_1=conds[1][0, 0],
#conds_0=conds[0][0, 0],
#conds_1=conds[1][0, 0],
text=text,
mel=mel,
@@ -782,9 +782,11 @@ def create_dataset_hdf5( skip_existing=True ):
if "audio" not in group:
group.create_dataset('audio', data=mel.numpy(), compression='lzf')
"""
for i, cond in enumerate(conds):
if f"conds_{i}" not in group:
group.create_dataset(f'conds_{i}', data=cond.numpy(), compression='lzf')
"""
for i, latent in enumerate(latents):
if f"latents_{i}" not in group:

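A minimal sketch of reading an entry back out of the HDF5 layout after this change (the file name and group key are placeholders; the `conds_*` datasets are no longer written or read, only the latents):

```python
import h5py
import torch

with h5py.File("./training/dataset.h5", "r") as hdf5:
    group = hdf5["speaker/utterance"]  # placeholder key
    text = torch.from_numpy(group["text"][:])
    mel = torch.from_numpy(group["audio"][:]).to(torch.int16)
    latents = (torch.from_numpy(group["latents_0"][:]),
               torch.from_numpy(group["latents_1"][:]))
    wav_length = group.attrs["wav_length"]
```
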
View File

@@ -6,14 +6,19 @@ from torch import Tensor
from einops import rearrange
from pathlib import Path
from .emb import g2p, qnt
from .emb.qnt import trim, trim_random
from .emb.mel import encode_from_files as encode_mel, trim, trim_random
from .utils import to_device
from .config import cfg
from .models import get_models
from .models import get_models, load_model
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
from .data import get_phone_symmap, tokenize
from .models.arch_utils import denormalize_tacotron_mel
from .models.diffusion import get_diffuser
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
if deepspeed_available:
import deepspeed
@@ -71,18 +76,10 @@ class TTS():
if isinstance( text, Tensor ):
return text
content = g2p.encode(text, language=language)
tokens = tokenize( content )
tokens = tokenize( text )
return torch.tensor( tokens )
def encode_lang( self, language ):
symmap = get_lang_symmap()
id = 0
if language in symmap:
id = symmap[language]
return torch.tensor([ id ])
def encode_audio( self, paths, trim_length=0.0 ):
# already a tensor, return it
if isinstance( paths, Tensor ):
@@ -93,62 +90,151 @@ class TTS():
paths = [ Path(p) for p in paths.split(";") ]
# merge inputs
proms = []
for path in paths:
prom = qnt.encode_from_file(path)
if hasattr( prom, "codes" ):
prom = prom.codes
proms.append( prom )
res = torch.cat(proms)
if trim_length:
res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )
return res
return encode_mel( paths, device=self.device )
@torch.inference_mode()
def inference(
self,
text,
references,
language="en",
max_ar_steps=6 * cfg.dataset.frames_per_second,
max_ar_context=-1,
input_prompt_length=0.0,
ar_temp=0.95,
diffusion_temp=0.5,
min_ar_temp=0.95,
min_diffusion_temp=0.5,
#language="en",
max_ar_steps=500,
max_diffusion_steps=80,
#max_ar_context=-1,
#input_prompt_length=0.0,
ar_temp=1.0,
diffusion_temp=1.0,
#min_ar_temp=0.95,
#min_diffusion_temp=0.5,
top_p=1.0,
top_k=0,
repetition_penalty=1.0,
repetition_penalty_decay=0.0,
#repetition_penalty_decay=0.0,
length_penalty=0.0,
beam_width=0,
mirostat_tau=0,
mirostat_eta=0.1,
beam_width=1,
#mirostat_tau=0,
#mirostat_eta=0.1,
out_path=None
):
lines = text.split("\n")
wavs = []
sr = None
sr = 24_000
autoregressive = None
diffusion = None
clvp = None
vocoder = None
diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=False)
autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]
for name, engine in self.engines.items():
...
if "autoregressive" in name:
autoregressive = engine.module
elif "diffusion" in name:
diffusion = engine.module
elif "clvp" in name:
clvp = engine.module
elif "vocoder" in name:
vocoder = engine.module
if autoregressive is None:
autoregressive = load_model("autoregressive", device=cfg.device)
if diffusion is None:
diffusion = load_model("diffusion", device=cfg.device)
if clvp is None:
clvp = load_model("clvp", device=cfg.device)
if vocoder is None:
vocoder = load_model("vocoder", device=cfg.device)
wavs = []
# other vars
calm_token = 832
for line in lines:
if out_path is None:
out_path = f"./data/{cfg.start_time}.wav"
...
text = self.encode_text( line ).to(device=cfg.device)
wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
wavs.append(wav)
text_tokens = pad_sequence([ text ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)
with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
# autoregressive pass
codes = autoregressive.inference_speech(
autoregressive_latents,
text_tokens,
do_sample=True,
top_p=top_p,
temperature=ar_temp,
num_return_sequences=1,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
max_generate_length=max_ar_steps,
)
padding_needed = max_ar_steps - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
for i, code in enumerate( codes ):
stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
if len(stop_token_indices) == 0:
continue
codes[i][stop_token_indices] = 83
stm = stop_token_indices.min().item()
codes[i][stm:] = 83
if stm - 3 < codes[i].shape[0]:
codes[i][-3] = 45
codes[i][-2] = 45
codes[i][-1] = 248
wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
latents = autoregressive.forward(
autoregressive_latents,
text_tokens,
text_lengths,
codes,
wav_lengths,
return_latent=True,
clip_inputs=False
)
calm_tokens = 0
for k in range( codes.shape[-1] ):
if codes[0, k] == calm_token:
calm_tokens += 1
else:
calm_tokens = 0
if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
latents = latents[:, :k]
break
# diffusion pass
output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
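# illustrative math, not from the source: 100 latent frames -> 100 * 4 * 24000 // 22050 = 435 output frames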
output_shape = (latents.shape[0], 100, output_seq_len)
precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
mel = diffuser.p_sample_loop(
diffusion,
output_shape,
noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
progress=True
)
mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
# vocoder pass
waves = vocoder.inference(mels)
for wav in waves:
if out_path is not None:
torchaudio.save( out_path, wav.cpu(), sr )
wavs.append(wav)
return (torch.concat(wavs, dim=-1), sr)
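
A minimal sketch of driving the rewritten inference path from Python (the import path and file names are assumptions; the keyword values mirror the signature above):

```python
from tortoise_tts.inference import TTS  # import path assumed

tts = TTS(config=None, device="cuda", dtype="float16", amp=True)
wav, sr = tts.inference(
    "Hello world.",
    ["./voices/reference.wav"],
    max_ar_steps=500,
    max_diffusion_steps=80,
    ar_temp=1.0,
    diffusion_temp=0.01,
    out_path="./out.wav",
)
# sr is fixed at 24_000 by this pipeline
```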

View File

@@ -10,6 +10,7 @@ from .diffusion import DiffusionTTS
from .vocoder import UnivNetGenerator
from .clvp import CLVP
from .dvae import DiscreteVAE
from .random_latent_generator import RandomLatentConverter
import os
import torch
@@ -20,18 +21,30 @@ DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '
@cache
def load_model(name, device="cuda", **kwargs):
load_path = None
if "autoregressive" in name or "unified_voice" in name:
state_dict_key = None
strict = True
if "rlg" in name:
if "autoregressive" in name:
model = RandomLatentConverter(1024, **kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/rlg_auto.pth'
if "diffusion" in name:
model = RandomLatentConverter(2048, **kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/rlg_diffuser.pth'
elif "autoregressive" in name or "unified_voice" in name:
strict = False
model = UnifiedVoice(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
elif "diffusion" in name:
model = DiffusionTTS(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
elif "clvp" in name:
model = CLVP(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
elif "vocoder" in name:
model = UnivNetGenerator(**kwargs)
load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
state_dict_key = 'model_g'
elif "dvae" in name:
load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
model = DiscreteVAE(**kwargs)
@@ -44,11 +57,16 @@ def load_model(name, device="cuda", **kwargs):
model = TacotronSTFT(**kwargs)
elif "tms" in name:
model = TorchMelSpectrogram(**kwargs)
model = model.to(device=device)
if load_path is not None:
model.load_state_dict(torch.load(load_path, map_location=device), strict=False)
state_dict = torch.load(load_path, map_location=device)
if state_dict_key:
state_dict = state_dict[state_dict_key]
model.load_state_dict(state_dict, strict=strict)
model.eval()
return model
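
This is the fix behind the commit message: the UnivNet vocoder checkpoint nests its generator weights under a `model_g` key, so loading the raw top-level dict silently mismatched the model. A minimal sketch of the corrected load (paths and import path assumed):

```python
import torch
from tortoise_tts.models.vocoder import UnivNetGenerator  # import path assumed

vocoder = UnivNetGenerator()
state_dict = torch.load("./models/vocoder.pth", map_location="cpu")
# the generator weights live under "model_g"; loading the top-level
# dict directly is what made the vocoder "load wrong" previously
vocoder.load_state_dict(state_dict["model_g"], strict=True)
vocoder.eval()
```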

View File

@@ -1565,6 +1565,21 @@ class DiffusionTTS(nn.Module):
return out, mel_pred
return out
def get_diffuser(
steps=80,
cond_free=True,
cond_free_k=2,
trained_diffusion_steps=4000,
):
return SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [steps]),
model_mean_type='epsilon',
model_var_type='learned_range',
loss_type='mse',
betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free,
conditioning_free_k=cond_free_k
)
if __name__ == '__main__':
clip = torch.randn(2, 100, 400)

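The new `get_diffuser` helper centralizes the sampler setup that the eval loop previously built inline; a minimal call, mirroring its use elsewhere in this commit:

```python
# cond_free=False skips the conditioning-free guidance pass,
# matching how the inference and eval paths call it
diffuser = get_diffuser(steps=80, cond_free=False)
```
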
tortoise_tts/tokenizer.py Normal file
View File

@@ -0,0 +1,175 @@
import os
import re
import inflect
import torch
from tokenizers import Tokenizer
# Regular expression matching whitespace:
from unidecode import unidecode
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
return text
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
return _inflect.number_to_words(m.group(0))
def _expand_number(m):
num = int(m.group(0))
if num > 1000 and num < 3000:
if num == 2000:
return 'two thousand'
elif num > 2000 and num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
elif num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
else:
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
else:
return _inflect.number_to_words(num, andword='')
def normalize_numbers(text):
text = re.sub(_comma_number_re, _remove_commas, text)
text = re.sub(_pounds_re, r'\1 pounds', text)
text = re.sub(_dollars_re, _expand_dollars, text)
text = re.sub(_decimal_number_re, _expand_decimal_point, text)
text = re.sub(_ordinal_re, _expand_ordinal, text)
text = re.sub(_number_re, _expand_number, text)
return text
def expand_numbers(text):
return normalize_numbers(text)
def lowercase(text):
return text.lower()
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text)
def convert_to_ascii(text):
return unidecode(text)
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
return text
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)
text = expand_abbreviations(text)
text = collapse_whitespace(text)
text = text.replace('"', '')
return text
class VoiceBpeTokenizer:
def __init__(self, tokenizer_file=None):
if tokenizer_file is not None:
self.tokenizer = Tokenizer.from_file(tokenizer_file)
def preprocess_text(self, txt):
txt = english_cleaners(txt)
return txt
def encode(self, txt):
txt = self.preprocess_text(txt)
txt = txt.replace(' ', '[SPACE]')
return self.tokenizer.encode(txt).ids
def decode(self, seq):
if isinstance(seq, torch.Tensor):
seq = seq.cpu().numpy()
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
txt = txt.replace('[SPACE]', ' ')
txt = txt.replace('[STOP]', '')
txt = txt.replace('[UNK]', '')
return txt
def get_vocab(self):
return self.tokenizer.get_vocab()
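
A quick round-trip with the relocated tokenizer (the tokenizer file path is a placeholder; exact ids depend on the vocab):

```python
tokenizer = VoiceBpeTokenizer("./models/tokenizer.json")  # path assumed
# cleaners run first, yielding roughly: "doctor smith paid five dollars on the third"
ids = tokenizer.encode("Dr. Smith paid $5 on the 3rd")
print(tokenizer.decode(ids))  # spaces round-trip through the [SPACE] token
```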

View File

@@ -25,7 +25,8 @@ import argparse
from torch.nn.utils.rnn import pad_sequence
from .models.arch_utils import denormalize_tacotron_mel
from .models.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
from .models.diffusion import get_diffuser
from .models import load_model
_logger = logging.getLogger(__name__)
@@ -36,15 +37,12 @@ def train_feeder(engine, batch):
device = batch["text"][0].device
batch_size = len(batch["text"])
autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token )
mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
@@ -68,39 +66,11 @@ def run_eval(engines, eval_name, dl):
stats = defaultdict(list)
stats['loss'] = []
def process( name, batch, resps_list ):
for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
if len(hyp) == 0:
continue
filename = f'{speaker}_{path.parts[-1]}'
if task != "tts":
filename = f"{filename}_{task}"
# to-do, refine the output dir to be sane-er
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
hyp_path.parent.mkdir(parents=True, exist_ok=True)
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
ref_audio, sr = emb.decode_to_file(ref, ref_path)
hyp_audio, sr = emb.decode_to_file(hyp, hyp_path)
prom_audio, sr = emb.decode_to_file(prom, prom_path)
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
autoregressive = None
diffusion = None
clvp = None
vocoder = None
diffuser = get_diffuser(steps=30, cond_free=False)
for name in engines:
engine = engines[name]
@@ -113,54 +83,44 @@ def run_eval(engines, eval_name, dl):
elif "vocoder" in name:
vocoder = engine.module
trained_diffusion_steps=4000
desired_diffusion_steps=50
cond_free=False
cond_free_k=1
diffuser = SpacedDiffusion(
use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
model_mean_type='epsilon',
model_var_type='learned_range',
loss_type='mse',
betas=get_named_beta_schedule('linear', trained_diffusion_steps),
conditioning_free=cond_free,
conditioning_free_k=cond_free_k
)
if autoregressive is None:
autoregressive = load_model("autoregressive", device=cfg.device)
if diffusion is None:
diffusion = load_model("diffusion", device=cfg.device)
if clvp is None:
clvp = load_model("clvp", device=cfg.device)
if vocoder is None:
vocoder = load_model("vocoder", device=cfg.device)
processed = 0
temperature = 1.0
while processed < cfg.evaluation.size:
batch: dict = to_device(next(iter(dl)), cfg.device)
processed += len(batch["text"])
max_mel_tokens = 500
def generate( batch, generate_codes=True ):
temperature = 1.0
max_mel_tokens = 500 # * autoregressive.mel_length_compression
stop_mel_token = autoregressive.stop_mel_token
calm_token = 83
verbose = True
verbose = False
autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
mel_codes = autoregressive.set_mel_padding(mel_codes, wav_lengths)
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
# autoregressive pass
if True:
if generate_codes:
codes = autoregressive.inference_speech(
autoregressive_latents,
text_tokens,
do_sample=True,
#top_p=top_p,
top_p=0.8,
temperature=temperature,
num_return_sequences=1,
#length_penalty=length_penalty,
#repetition_penalty=repetition_penalty,
length_penalty=1.0,
repetition_penalty=2.0,
max_generate_length=max_mel_tokens,
)
padding_needed = max_mel_tokens - codes.shape[1]
@@ -168,6 +128,22 @@ def run_eval(engines, eval_name, dl):
else:
codes = mel_codes
for i, code in enumerate( codes ):
stop_token_indices = (codes[i] == stop_mel_token).nonzero()
if len(stop_token_indices) == 0:
continue
codes[i][stop_token_indices] = 83
stm = stop_token_indices.min().item()
codes[i][stm:] = 83
if stm - 3 < codes[i].shape[0]:
codes[i][-3] = 45
codes[i][-2] = 45
codes[i][-1] = 248
wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
latents = autoregressive.forward(
autoregressive_latents,
text_tokens,
@@ -199,18 +175,47 @@ def run_eval(engines, eval_name, dl):
output_shape,
noise=noise,
model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
progress=verbose
progress=True
)
mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
# vocoder pass
wavs = vocoder.inference(mels)
for i, wav in enumerate( wavs ):
torchaudio.save( f"./data/{cfg.start_time}[{i}].wav", wav.cpu(), 24_000 )
return wavs
# process( name, batch, resps_list )
def process( name, batch, hyps, refs ):
for speaker, path, ref_audio, hyp_audio in zip(batch["spkr_name"], batch["path"], refs, hyps):
filename = f'{speaker}_{path.parts[-1]}'
# to-do, refine the output dir to be sane-er
ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
hyp_path.parent.mkdir(parents=True, exist_ok=True)
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
torchaudio.save( hyp_path, hyp_audio.cpu(), 24_000 )
torchaudio.save( ref_path, ref_audio.cpu(), 24_000 )
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
processed = 0
while processed < cfg.evaluation.size:
batch = to_device(next(iter(dl)), cfg.device)
batch_size = len(batch["text"])
processed += batch_size
hyp = generate( batch, generate_codes=True )
ref = generate( batch, generate_codes=False )
process( name, batch, hyp, ref )
stats = {k: sum(v) / len(v) for k, v in stats.items()}
engines_stats = {
@@ -254,7 +259,7 @@ def train():
do_gc()
#qnt.unload_model()
if args.eval:
return eval_fn(engines=trainer.load_engines())

View File

@@ -1 +0,0 @@
__version__ = "0.0.1-dev20240617224834"