From fb313d7ef4ff0a65ad6826688298cd6c643c672a Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 18 Jun 2024 20:55:50 -0500
Subject: [PATCH] working, the vocoder was just loading wrong

---
 README.md                        |   6 +-
 tortoise_tts/__main__.py         |  59 +++++-----
 tortoise_tts/config.py           | 172 +----------------------------
 tortoise_tts/data.py             |  12 ++-
 tortoise_tts/inference.py        | 178 +++++++++++++++++++++++--------
 tortoise_tts/models/__init__.py  |  26 ++++-
 tortoise_tts/models/diffusion.py |  15 +++
 tortoise_tts/tokenizer.py        | 175 ++++++++++++++++++++++++++++++
 tortoise_tts/train.py            | 155 ++++++++++++++-------------
 tortoise_tts/version.py          |   1 -
 10 files changed, 469 insertions(+), 330 deletions(-)
 create mode 100644 tortoise_tts/tokenizer.py
 delete mode 100644 tortoise_tts/version.py

diff --git a/README.md b/README.md
index 56c488a..df6fbbc 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,9 @@ Simply run `pip install git+https://git.ecker.tech/mrq/tortoise-tts` or `pip ins

 ## To-Do

-- [ ] Reimplement original inferencing through TorToiSe (as done with `api.py`)
-- [ ] Implement training support (without DLAS)
-  - [ ] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
+- [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
+- [X] Implement training support (without DLAS)
+  - [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
 - [ ] Automagic handling of the original weights into compatible weights
 - [ ] Extend the original inference routine with additional features:
   - [x] non-float32 / mixed precision
parser.add_argument("--amp", action="store_true") parser.add_argument("--dtype", type=str, default=None) + """ + parser.add_argument("--language", type=str, default="en") + parser.add_argument("--max-nar-levels", type=int, default=7) + parser.add_argument("--max-ar-context", type=int, default=-1) + + #parser.add_argument("--min-ar-temp", type=float, default=-1.0) + #parser.add_argument("--min-nar-temp", type=float, default=-1.0) + #parser.add_argument("--input-prompt-length", type=float, default=3.0) + + + arser.add_argument("--mirostat-tau", type=float, default=0) + arser.add_argument("--mirostat-eta", type=float, default=0) + """ args = parser.parse_args() tts = TTS( config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp ) tts.inference( text=args.text, references=args.references, - language=args.language, out_path=args.out_path, - input_prompt_length=args.input_prompt_length, - max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels, - max_ar_context=args.max_ar_context, - ar_temp=args.ar_temp, nar_temp=args.nar_temp, - min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp, - top_p=args.top_p, top_k=args.top_k, - repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, + max_ar_steps=args.max_ar_steps, + max_diffusion_steps=args.max_diffusion_steps, + ar_temp=args.ar_temp, + diffusion_temp=args.diffusion_temp, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + #repetition_penalty_decay=args.repetition_penalty_decay, length_penalty=args.length_penalty, beam_width=args.beam_width, + ) + """ + language=args.language, + input_prompt_length=args.input_prompt_length, + max_nar_levels=args.max_nar_levels, + max_ar_context=args.max_ar_context, + min_ar_temp=args.min_ar_temp, min_nar_temp=args.min_nar_temp, mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta ) + """ if __name__ == "__main__": main() diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py index 29f179d..9ff4254 100755 --- a/tortoise_tts/config.py +++ b/tortoise_tts/config.py @@ -17,6 +17,7 @@ from functools import cached_property from pathlib import Path from .utils.distributed import world_size +from .tokenizer import VoiceBpeTokenizer # Yuck from transformers import PreTrainedTokenizerFast @@ -496,177 +497,6 @@ class Inference: return torch.float8_e4m3fn return torch.float32 -import inflect -import re - -# Regular expression matching whitespace: -from unidecode import unidecode - -_whitespace_re = re.compile(r'\s+') - -# List of (regular expression, replacement) pairs for abbreviations: -_abbreviations = [(re.compile('\\b%s\\.' 
diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 29f179d..9ff4254 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -17,6 +17,7 @@ from functools import cached_property
 from pathlib import Path

 from .utils.distributed import world_size
+from .tokenizer import VoiceBpeTokenizer

 # Yuck
 from transformers import PreTrainedTokenizerFast
@@ -496,177 +497,6 @@ class Inference:
             return torch.float8_e4m3fn
         return torch.float32

-import inflect
-import re
-
-# Regular expression matching whitespace:
-from unidecode import unidecode
-
-_whitespace_re = re.compile(r'\s+')
-
-# List of (regular expression, replacement) pairs for abbreviations:
-_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
-    ('mrs', 'misess'),
-    ('mr', 'mister'),
-    ('dr', 'doctor'),
-    ('st', 'saint'),
-    ('co', 'company'),
-    ('jr', 'junior'),
-    ('maj', 'major'),
-    ('gen', 'general'),
-    ('drs', 'doctors'),
-    ('rev', 'reverend'),
-    ('lt', 'lieutenant'),
-    ('hon', 'honorable'),
-    ('sgt', 'sergeant'),
-    ('capt', 'captain'),
-    ('esq', 'esquire'),
-    ('ltd', 'limited'),
-    ('col', 'colonel'),
-    ('ft', 'fort'),
-]]
-
-
-def expand_abbreviations(text):
-    for regex, replacement in _abbreviations:
-        text = re.sub(regex, replacement, text)
-    return text
-
-
-_inflect = inflect.engine()
-_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
-_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
-_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
-_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
-_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
-_number_re = re.compile(r'[0-9]+')
-
-
-def _remove_commas(m):
-    return m.group(1).replace(',', '')
-
-
-def _expand_decimal_point(m):
-    return m.group(1).replace('.', ' point ')
-
-
-def _expand_dollars(m):
-    match = m.group(1)
-    parts = match.split('.')
-    if len(parts) > 2:
-        return match + ' dollars'  # Unexpected format
-    dollars = int(parts[0]) if parts[0] else 0
-    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
-    if dollars and cents:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
-    elif dollars:
-        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
-        return '%s %s' % (dollars, dollar_unit)
-    elif cents:
-        cent_unit = 'cent' if cents == 1 else 'cents'
-        return '%s %s' % (cents, cent_unit)
-    else:
-        return 'zero dollars'
-
-
-def _expand_ordinal(m):
-    return _inflect.number_to_words(m.group(0))
-
-
-def _expand_number(m):
-    num = int(m.group(0))
-    if num > 1000 and num < 3000:
-        if num == 2000:
-            return 'two thousand'
-        elif num > 2000 and num < 2010:
-            return 'two thousand ' + _inflect.number_to_words(num % 100)
-        elif num % 100 == 0:
-            return _inflect.number_to_words(num // 100) + ' hundred'
-        else:
-            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
-    else:
-        return _inflect.number_to_words(num, andword='')
-
-
-def normalize_numbers(text):
-    text = re.sub(_comma_number_re, _remove_commas, text)
-    text = re.sub(_pounds_re, r'\1 pounds', text)
-    text = re.sub(_dollars_re, _expand_dollars, text)
-    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
-    text = re.sub(_ordinal_re, _expand_ordinal, text)
-    text = re.sub(_number_re, _expand_number, text)
-    return text
-
-
-def expand_numbers(text):
-    return normalize_numbers(text)
-
-
-def lowercase(text):
-    return text.lower()
-
-
-def collapse_whitespace(text):
-    return re.sub(_whitespace_re, ' ', text)
-
-
-def convert_to_ascii(text):
-    return unidecode(text)
-
-
-def basic_cleaners(text):
-    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
-    text = lowercase(text)
-    text = collapse_whitespace(text)
-    return text
-
-
-def transliteration_cleaners(text):
-    '''Pipeline for non-English text that transliterates to ASCII.'''
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = collapse_whitespace(text)
-    return text
-
-
-def english_cleaners(text):
-    '''Pipeline for English text, including number and abbreviation expansion.'''
-    text = convert_to_ascii(text)
-    text = lowercase(text)
-    text = expand_numbers(text)
-    text = expand_abbreviations(text)
-    text = collapse_whitespace(text)
-    text = text.replace('"', '')
-    return text
-
-class VoiceBpeTokenizer:
-    def __init__(self, tokenizer_file=None):
-        if tokenizer_file is not None:
-            self.tokenizer = Tokenizer.from_file(tokenizer_file)
-
-    def preprocess_text(self, txt):
-        txt = english_cleaners(txt)
-        return txt
-
-    def encode(self, txt):
-        txt = self.preprocess_text(txt)
-        txt = txt.replace(' ', '[SPACE]')
-        return self.tokenizer.encode(txt).ids
-
-    def decode(self, seq):
-        if isinstance(seq, torch.Tensor):
-            seq = seq.cpu().numpy()
-        txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
-        txt = txt.replace('[SPACE]', ' ')
-        txt = txt.replace('[STOP]', '')
-        txt = txt.replace('[UNK]', '')
-        return txt
-
-    def get_vocab(self):
-        return self.tokenizer.get_vocab()
-
 # should be renamed to optimizations
 @dataclass()
 class Optimizations:
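The ~170 lines removed above are not dropped: they reappear essentially verbatim as the new tortoise_tts/tokenizer.py later in this patch, so the only functional change to config.py is the added import. Anything that previously reached into config.py for the tokenizer keeps working through it:

    # config.py now only carries the import; the cleaners and VoiceBpeTokenizer
    # themselves live in tortoise_tts/tokenizer.py (added below).
    from .tokenizer import VoiceBpeTokenizer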
diff --git a/tortoise_tts/data.py b/tortoise_tts/data.py
index 4b15614..130a3fa 100755
--- a/tortoise_tts/data.py
+++ b/tortoise_tts/data.py
@@ -499,19 +499,19 @@ class Dataset(_Dataset):
             text = cfg.hdf5[key]["text"][:]
             mel = cfg.hdf5[key]["audio"][:]
-            conds = (cfg.hdf5[key]["conds_0"][:], cfg.hdf5[key]["conds_1"][:])
+            #conds = (cfg.hdf5[key]["conds_0"][:], cfg.hdf5[key]["conds_1"][:])
             latents = (cfg.hdf5[key]["latents_0"][:], cfg.hdf5[key]["latents_1"][:])

             text = torch.from_numpy(text).to(self.text_dtype)
             mel = torch.from_numpy(mel).to(torch.int16)
-            conds = (torch.from_numpy(conds[0]), torch.from_numpy(conds[1]))
+            #conds = (torch.from_numpy(conds[0]), torch.from_numpy(conds[1]))
             latents = (torch.from_numpy(latents[0]), torch.from_numpy(latents[1]))

             wav_length = cfg.hdf5[key].attrs["wav_length"]
         else:
             mel, metadata = _load_mels(path, return_metadata=True)
             text = torch.tensor(metadata["text"]).to(self.text_dtype)
-            conds = (torch.from_numpy(metadata["conds"][0]), torch.from_numpy(metadata["conds"][1]))
+            #conds = (torch.from_numpy(metadata["conds"][0]), torch.from_numpy(metadata["conds"][1]))
             latents = (torch.from_numpy(metadata["latent"][0]), torch.from_numpy(metadata["latent"][1]))
             wav_length = metadata["wav_length"]

@@ -524,8 +524,8 @@ class Dataset(_Dataset):
             latents_0=latents[0][0],
             latents_1=latents[1][0],

-            conds_0=conds[0][0, 0],
-            conds_1=conds[1][0, 0],
+            #conds_0=conds[0][0, 0],
+            #conds_1=conds[1][0, 0],

             text=text,
             mel=mel,
@@ -782,9 +782,11 @@ def create_dataset_hdf5( skip_existing=True ):
                 if "audio" not in group:
                     group.create_dataset('audio', data=mel.numpy(), compression='lzf')

+                """
                 for i, cond in enumerate(conds):
                     if f"conds_{i}" not in group:
                         group.create_dataset(f'conds_{i}', data=cond.numpy(), compression='lzf')
+                """

                 for i, latent in enumerate(latents):
                     if f"latents_{i}" not in group:
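With the conds_* datasets commented out, each HDF5 group now carries only the text tokens, the discrete mel codes, the two conditioning latents, and a wav_length attribute. A sketch of what one entry looks like after this change (the file path and group key are placeholders, assuming h5py):

    import h5py

    with h5py.File("./training/dataset.h5", "r") as hdf5:
        group = hdf5["speaker/utterance"]
        text = group["text"][:]                                   # tokenized transcript
        mel = group["audio"][:]                                   # discrete mel codes
        latents = (group["latents_0"][:], group["latents_1"][:])  # AR / diffusion latents
        wav_length = group.attrs["wav_length"]
        # conds_0 / conds_1 are no longer written or read; only the latents
        # are consumed by train.py and inference.py.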
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index c58ff37..99c0be8 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -6,14 +6,19 @@ from torch import Tensor
 from einops import rearrange
 from pathlib import Path

-from .emb import g2p, qnt
-from .emb.qnt import trim, trim_random
+from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device

 from .config import cfg
-from .models import get_models
+from .models import get_models, load_model
 from .engines import load_engines, deepspeed_available
-from .data import get_phone_symmap, get_lang_symmap, _load_quants, _cleanup_phones, tokenize
+from .data import get_phone_symmap, tokenize
+
+from .models.arch_utils import denormalize_tacotron_mel
+from .models.diffusion import get_diffuser
+
+from torch.nn.utils.rnn import pad_sequence
+import torch.nn.functional as F

 if deepspeed_available:
     import deepspeed
@@ -71,18 +76,10 @@ class TTS():
         if isinstance( text, Tensor ):
             return text

-        content = g2p.encode(text, language=language)
-        tokens = tokenize( content )
+        tokens = tokenize( text )

         return torch.tensor( tokens )

-    def encode_lang( self, language ):
-        symmap = get_lang_symmap()
-        id = 0
-        if language in symmap:
-            id = symmap[language]
-        return torch.tensor([ id ])
-
     def encode_audio( self, paths, trim_length=0.0 ):
         # already a tensor, return it
         if isinstance( paths, Tensor ):
@@ -93,62 +90,151 @@ class TTS():
             paths = [ Path(p) for p in paths.split(";") ]

         # merge inputs
-
-        proms = []
-
-        for path in paths:
-            prom = qnt.encode_from_file(path)
-            if hasattr( prom, "codes" ):
-                prom = prom.codes
-
-            proms.append( prom )
-
-        res = torch.cat(proms)
-
-        if trim_length:
-            res = trim( res, int( cfg.dataset.frames_per_second * trim_length ) )
-
-        return res
+        return encode_mel( paths, device=self.device )

     @torch.inference_mode()
     def inference(
         self,
         text,
         references,
-        language="en",
-        max_ar_steps=6 * cfg.dataset.frames_per_second,
-        max_ar_context=-1,
-        input_prompt_length=0.0,
-        ar_temp=0.95,
-        diffusion_temp=0.5,
-        min_ar_temp=0.95,
-        min_diffusion_temp=0.5,
+        #language="en",
+        max_ar_steps=500,
+        max_diffusion_steps=80,
+        #max_ar_context=-1,
+        #input_prompt_length=0.0,
+        ar_temp=1.0,
+        diffusion_temp=1.0,
+        #min_ar_temp=0.95,
+        #min_diffusion_temp=0.5,
         top_p=1.0,
         top_k=0,
         repetition_penalty=1.0,
-        repetition_penalty_decay=0.0,
+        #repetition_penalty_decay=0.0,
         length_penalty=0.0,
-        beam_width=0,
-        mirostat_tau=0,
-        mirostat_eta=0.1,
+        beam_width=1,
+        #mirostat_tau=0,
+        #mirostat_eta=0.1,
         out_path=None
     ):
         lines = text.split("\n")

         wavs = []
-        sr = None
+        sr = 24_000
+
+        autoregressive = None
+        diffusion = None
+        clvp = None
+        vocoder = None
+        diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=False)
+
+        autoregressive_latents, diffusion_latents = self.encode_audio( references )["latent"]

         for name, engine in self.engines.items():
-            ...
+            if "autoregressive" in name:
+                autoregressive = engine.module
+            elif "diffusion" in name:
+                diffusion = engine.module
+            elif "clvp" in name:
+                clvp = engine.module
+            elif "vocoder" in name:
+                vocoder = engine.module
+
+        if autoregressive is None:
+            autoregressive = load_model("autoregressive", device=cfg.device)
+        if diffusion is None:
+            diffusion = load_model("diffusion", device=cfg.device)
+        if clvp is None:
+            clvp = load_model("clvp", device=cfg.device)
+        if vocoder is None:
+            vocoder = load_model("vocoder", device=cfg.device)
+
+        wavs = []
+        # other vars
+        calm_token = 832

         for line in lines:
             if out_path is None:
                 out_path = f"./data/{cfg.start_time}.wav"

-            ...
+            text = self.encode_text( line ).to(device=cfg.device)

-            wav, sr = qnt.decode_to_file(resps_list[0], out_path, device=self.device)
-            wavs.append(wav)
+            text_tokens = pad_sequence([ text ], batch_first = True)
+            text_lengths = torch.Tensor([ text.shape[0] ]).to(dtype=torch.int32)
+
+            with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
+                # autoregressive pass
+                codes = autoregressive.inference_speech(
+                    autoregressive_latents,
+                    text_tokens,
+                    do_sample=True,
+                    top_p=top_p,
+                    temperature=ar_temp,
+                    num_return_sequences=1,
+                    length_penalty=length_penalty,
+                    repetition_penalty=repetition_penalty,
+                    max_generate_length=max_ar_steps,
+                )
+                padding_needed = max_ar_steps - codes.shape[1]
+                codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
+
+                for i, code in enumerate( codes ):
+                    stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
+
+                    if len(stop_token_indices) == 0:
+                        continue
+
+                    codes[i][stop_token_indices] = 83
+                    stm = stop_token_indices.min().item()
+                    codes[i][stm:] = 83
+                    if stm - 3 < codes[i].shape[0]:
+                        codes[i][-3] = 45
+                        codes[i][-2] = 45
+                        codes[i][-1] = 248
+
+                wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
+
+                latents = autoregressive.forward(
+                    autoregressive_latents,
+                    text_tokens,
+                    text_lengths,
+                    codes,
+                    wav_lengths,
+                    return_latent=True,
+                    clip_inputs=False
+                )
+
+                calm_tokens = 0
+                for k in range( codes.shape[-1] ):
+                    if codes[0, k] == calm_token:
+                        calm_tokens += 1
+                    else:
+                        calm_tokens = 0
+                    if calm_tokens > 8: # 8 tokens gives the diffusion model some "breathing room" to terminate speech.
+                        latents = latents[:, :k]
+                        break
+
+                # diffusion pass
+                output_seq_len = latents.shape[1] * 4 * 24000 // 22050 # This diffusion model converts from 22kHz spectrogram codes to a 24kHz spectrogram signal.
+                output_shape = (latents.shape[0], 100, output_seq_len)
+                precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
+
+                noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
+                mel = diffuser.p_sample_loop(
+                    diffusion,
+                    output_shape,
+                    noise=noise,
+                    model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
+                    progress=True
+                )
+                mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
+
+                # vocoder pass
+                waves = vocoder.inference(mels)
+
+                for wav in waves:
+                    if out_path is not None:
+                        torchaudio.save( out_path, wav.cpu(), sr )
+
+                    wavs.append(wav)

         return (torch.concat(wavs, dim=-1), sr)
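The rewritten inference() is now the full TorToiSe pipeline: an autoregressive pass produces discrete mel codes, a second AR pass with return_latent=True yields the latents the diffusion model denoises into a mel spectrogram, and UnivNet vocodes that to 24 kHz audio. (CLVP is loaded but, with num_return_sequences=1, there is nothing to rerank yet.) The only non-obvious bookkeeping is the length conversion; as a worked example of the expression above:

    # Each AR latent frame covers 4 mel frames at the diffusion model's native
    # 22.05 kHz rate, rescaled to the vocoder's 24 kHz output rate.
    latent_frames = 100                                  # example value
    output_seq_len = latent_frames * 4 * 24000 // 22050  # = 435 mel frames
    output_shape = (1, 100, output_seq_len)              # (batch, mel channels, frames)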
diff --git a/tortoise_tts/models/__init__.py b/tortoise_tts/models/__init__.py
index ab3cb93..d77014f 100755
--- a/tortoise_tts/models/__init__.py
+++ b/tortoise_tts/models/__init__.py
@@ -10,6 +10,7 @@ from .diffusion import DiffusionTTS
 from .vocoder import UnivNetGenerator
 from .clvp import CLVP
 from .dvae import DiscreteVAE
+from .random_latent_generator import RandomLatentConverter

 import os
 import torch
@@ -20,18 +21,30 @@ DEFAULT_MODEL_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '
 @cache
 def load_model(name, device="cuda", **kwargs):
     load_path = None
-    if "autoregressive" in name or "unified_voice" in name:
+    state_dict_key = None
+    strict = True
+
+    if "rlg" in name:
+        if "autoregressive" in name:
+            model = RandomLatentConverter(1024, **kwargs)
+            load_path = f'{DEFAULT_MODEL_PATH}/rlg_auto.pth'
+        if "diffusion" in name:
+            model = RandomLatentConverter(2048, **kwargs)
+            load_path = f'{DEFAULT_MODEL_PATH}/rlg_diffuser.pth'
+    elif "autoregressive" in name or "unified_voice" in name:
+        strict = False
         model = UnifiedVoice(**kwargs)
         load_path = f'{DEFAULT_MODEL_PATH}/autoregressive.pth'
     elif "diffusion" in name:
         model = DiffusionTTS(**kwargs)
-        load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
+        load_path = f'{DEFAULT_MODEL_PATH}/diffusion.pth'
     elif "clvp" in name:
         model = CLVP(**kwargs)
         load_path = f'{DEFAULT_MODEL_PATH}/clvp2.pth'
     elif "vocoder" in name:
         model = UnivNetGenerator(**kwargs)
         load_path = f'{DEFAULT_MODEL_PATH}/vocoder.pth'
+        state_dict_key = 'model_g'
     elif "dvae" in name:
         load_path = f'{DEFAULT_MODEL_PATH}/dvae.pth'
         model = DiscreteVAE(**kwargs)
@@ -44,11 +57,16 @@ def load_model(name, device="cuda", **kwargs):
         model = TacotronSTFT(**kwargs)
     elif "tms" in name:
         model = TorchMelSpectrogram(**kwargs)
-
+
     model = model.to(device=device)

     if load_path is not None:
-        model.load_state_dict(torch.load(load_path, map_location=device), strict=False)
+        state_dict = torch.load(load_path, map_location=device)
+        if state_dict_key:
+            state_dict = state_dict[state_dict_key]
+
+        model.load_state_dict(state_dict, strict=strict)
+
+    model.eval()

     return model
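This hunk is the commit subject: the UnivNet checkpoint nests its weights under a 'model_g' key, so the old unconditional strict=False load silently matched nothing and would leave the vocoder at its random initialization. The gist of the fix, in isolation (the checkpoint path is a placeholder):

    import torch

    # Sketch of the corrected load path for the vocoder checkpoint.
    state_dict = torch.load("./models/vocoder.pth", map_location="cpu")
    state_dict = state_dict["model_g"]  # unwrap the generator weights
    # model.load_state_dict(state_dict, strict=True)  # strict load surfaces mismatches

load_model also now flips every model to eval mode after loading, and only the autoregressive path keeps strict=False.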
diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py
index 00b7737..12d6257 100644
--- a/tortoise_tts/models/diffusion.py
+++ b/tortoise_tts/models/diffusion.py
@@ -1565,6 +1565,21 @@ class DiffusionTTS(nn.Module):
             return out, mel_pred
         return out

+def get_diffuser(
+    steps=80,
+    cond_free=True,
+    cond_free_k=2,
+    trained_diffusion_steps=4000,
+):
+    return SpacedDiffusion(
+        use_timesteps=space_timesteps(trained_diffusion_steps, [steps]),
+        model_mean_type='epsilon',
+        model_var_type='learned_range',
+        loss_type='mse',
+        betas=get_named_beta_schedule('linear', trained_diffusion_steps),
+        conditioning_free=cond_free,
+        conditioning_free_k=cond_free_k
+    )

 if __name__ == '__main__':
     clip = torch.randn(2, 100, 400)
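get_diffuser centralizes the SpacedDiffusion construction that train.py previously inlined; the sampling schedule is always spaced down from the 4000 trained steps. Usage as it appears elsewhere in this patch:

    diffuser = get_diffuser(steps=max_diffusion_steps, cond_free=False)  # inference.py, default 80 steps
    eval_diffuser = get_diffuser(steps=30, cond_free=False)              # train.py's run_eval, cheaper schedule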
diff --git a/tortoise_tts/tokenizer.py b/tortoise_tts/tokenizer.py
new file mode 100644
index 0000000..59c3645
--- /dev/null
+++ b/tortoise_tts/tokenizer.py
@@ -0,0 +1,175 @@
+import os
+import re
+
+import inflect
+import torch
+from tokenizers import Tokenizer
+
+
+# Regular expression matching whitespace:
+from unidecode import unidecode
+
+_whitespace_re = re.compile(r'\s+')
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+    ('mrs', 'misess'),
+    ('mr', 'mister'),
+    ('dr', 'doctor'),
+    ('st', 'saint'),
+    ('co', 'company'),
+    ('jr', 'junior'),
+    ('maj', 'major'),
+    ('gen', 'general'),
+    ('drs', 'doctors'),
+    ('rev', 'reverend'),
+    ('lt', 'lieutenant'),
+    ('hon', 'honorable'),
+    ('sgt', 'sergeant'),
+    ('capt', 'captain'),
+    ('esq', 'esquire'),
+    ('ltd', 'limited'),
+    ('col', 'colonel'),
+    ('ft', 'fort'),
+]]
+
+
+def expand_abbreviations(text):
+    for regex, replacement in _abbreviations:
+        text = re.sub(regex, replacement, text)
+    return text
+
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+_number_re = re.compile(r'[0-9]+')
+
+
+def _remove_commas(m):
+    return m.group(1).replace(',', '')
+
+
+def _expand_decimal_point(m):
+    return m.group(1).replace('.', ' point ')
+
+
+def _expand_dollars(m):
+    match = m.group(1)
+    parts = match.split('.')
+    if len(parts) > 2:
+        return match + ' dollars'  # Unexpected format
+    dollars = int(parts[0]) if parts[0] else 0
+    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+    if dollars and cents:
+        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+        cent_unit = 'cent' if cents == 1 else 'cents'
+        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+    elif dollars:
+        dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+        return '%s %s' % (dollars, dollar_unit)
+    elif cents:
+        cent_unit = 'cent' if cents == 1 else 'cents'
+        return '%s %s' % (cents, cent_unit)
+    else:
+        return 'zero dollars'
+
+
+def _expand_ordinal(m):
+    return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+    num = int(m.group(0))
+    if num > 1000 and num < 3000:
+        if num == 2000:
+            return 'two thousand'
+        elif num > 2000 and num < 2010:
+            return 'two thousand ' + _inflect.number_to_words(num % 100)
+        elif num % 100 == 0:
+            return _inflect.number_to_words(num // 100) + ' hundred'
+        else:
+            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+    else:
+        return _inflect.number_to_words(num, andword='')
+
+
+def normalize_numbers(text):
+    text = re.sub(_comma_number_re, _remove_commas, text)
+    text = re.sub(_pounds_re, r'\1 pounds', text)
+    text = re.sub(_dollars_re, _expand_dollars, text)
+    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+    text = re.sub(_ordinal_re, _expand_ordinal, text)
+    text = re.sub(_number_re, _expand_number, text)
+    return text
+
+
+def expand_numbers(text):
+    return normalize_numbers(text)
+
+
+def lowercase(text):
+    return text.lower()
+
+
+def collapse_whitespace(text):
+    return re.sub(_whitespace_re, ' ', text)
+
+
+def convert_to_ascii(text):
+    return unidecode(text)
+
+
+def basic_cleaners(text):
+    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
+def transliteration_cleaners(text):
+    '''Pipeline for non-English text that transliterates to ASCII.'''
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
+def english_cleaners(text):
+    '''Pipeline for English text, including number and abbreviation expansion.'''
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = expand_numbers(text)
+    text = expand_abbreviations(text)
+    text = collapse_whitespace(text)
+    text = text.replace('"', '')
+    return text
+
+class VoiceBpeTokenizer:
+    def __init__(self, tokenizer_file=None):
+        if tokenizer_file is not None:
+            self.tokenizer = Tokenizer.from_file(tokenizer_file)
+
+    def preprocess_text(self, txt):
+        txt = english_cleaners(txt)
+        return txt
+
+    def encode(self, txt):
+        txt = self.preprocess_text(txt)
+        txt = txt.replace(' ', '[SPACE]')
+        return self.tokenizer.encode(txt).ids
+
+    def decode(self, seq):
+        if isinstance(seq, torch.Tensor):
+            seq = seq.cpu().numpy()
+        txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(' ', '')
+        txt = txt.replace('[SPACE]', ' ')
+        txt = txt.replace('[STOP]', '')
+        txt = txt.replace('[UNK]', '')
+        return txt
+
+    def get_vocab(self):
+        return self.tokenizer.get_vocab()
\ No newline at end of file
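The module is a straight lift of the Tacotron-style English cleaners plus the BPE wrapper removed from config.py. A round-trip sketch, assuming a tokenizer.json whose vocabulary includes the [SPACE], [STOP], and [UNK] specials (the path is a placeholder):

    from tortoise_tts.tokenizer import VoiceBpeTokenizer

    tokenizer = VoiceBpeTokenizer(tokenizer_file="./data/tokenizer.json")
    # english_cleaners turns 'Dr. Smith owes $5.50' into
    # 'doctor smith owes five dollars, fifty cents'; spaces then become [SPACE].
    ids = tokenizer.encode("Dr. Smith owes $5.50")
    text = tokenizer.decode(ids)  # [SPACE] -> ' ', [STOP]/[UNK] stripped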
diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py
index 7da8399..2b45ba0 100755
--- a/tortoise_tts/train.py
+++ b/tortoise_tts/train.py
@@ -25,7 +25,8 @@ import argparse
 from torch.nn.utils.rnn import pad_sequence

 from .models.arch_utils import denormalize_tacotron_mel
-from .models.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
+from .models.diffusion import get_diffuser
+from .models import load_model

 _logger = logging.getLogger(__name__)

@@ -36,15 +37,12 @@ def train_feeder(engine, batch):
         device = batch["text"][0].device
         batch_size = len(batch["text"])

-        autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
-        diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
-
         autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
         diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])

         text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
         text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
-        mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token )
+        mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
         wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)

         engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
@@ -68,39 +66,11 @@ def run_eval(engines, eval_name, dl):
     stats = defaultdict(list)
     stats['loss'] = []

-    def process( name, batch, resps_list ):
-        for speaker, path, ref, hyp, prom, task in zip(batch["spkr_name"], batch["path"], batch["resps"], resps_list, batch["proms"], batch["task"]):
-            if len(hyp) == 0:
-                continue
-
-            filename = f'{speaker}_{path.parts[-1]}'
-
-            if task != "tts":
-                filename = f"{filename}_{task}"
-
-            # to-do, refine the output dir to be sane-er
-            ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
-            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
-            prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
-
-            hyp_path.parent.mkdir(parents=True, exist_ok=True)
-            ref_path.parent.mkdir(parents=True, exist_ok=True)
-            prom_path.parent.mkdir(parents=True, exist_ok=True)
-
-            ref_audio, sr = emb.decode_to_file(ref, ref_path)
-            hyp_audio, sr = emb.decode_to_file(hyp, hyp_path)
-            prom_audio, sr = emb.decode_to_file(prom, prom_path)
-
-            # pseudo loss calculation since we don't get the logits during eval
-            min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
-            ref_audio = ref_audio[..., 0:min_length]
-            hyp_audio = hyp_audio[..., 0:min_length]
-            stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
-
     autoregressive = None
     diffusion = None
     clvp = None
     vocoder = None
+    diffuser = get_diffuser(steps=30, cond_free=False)

     for name in engines:
         engine = engines[name]
@@ -113,54 +83,44 @@ def run_eval(engines, eval_name, dl):
         elif "vocoder" in name:
             vocoder = engine.module

-    trained_diffusion_steps=4000
-    desired_diffusion_steps=50
-    cond_free=False
-    cond_free_k=1
-
-    diffuser = SpacedDiffusion(
-        use_timesteps=space_timesteps(trained_diffusion_steps, [desired_diffusion_steps]),
-        model_mean_type='epsilon',
-        model_var_type='learned_range',
-        loss_type='mse',
-        betas=get_named_beta_schedule('linear', trained_diffusion_steps),
-        conditioning_free=cond_free,
-        conditioning_free_k=cond_free_k
-    )
+    if autoregressive is None:
+        autoregressive = load_model("autoregressive", device=cfg.device)
+    if diffusion is None:
+        diffusion = load_model("diffusion", device=cfg.device)
+    if clvp is None:
+        clvp = load_model("clvp", device=cfg.device)
+    if vocoder is None:
+        vocoder = load_model("vocoder", device=cfg.device)

-    processed = 0
-    temperature = 1.0
-    while processed < cfg.evaluation.size:
-        batch: dict = to_device(next(iter(dl)), cfg.device)
-        processed += len(batch["text"])
-
-        max_mel_tokens = 500
+    def generate( batch, generate_codes=True ):
+        temperature = 1.0
+        max_mel_tokens = 500 # * autoregressive.mel_length_compression
         stop_mel_token = autoregressive.stop_mel_token
         calm_token = 83
-        verbose = True
+        verbose = False
+
+        autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
+        diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
+
+        text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
+        text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
+        mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
+        wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
+
+        mel_codes = autoregressive.set_mel_padding(mel_codes, wav_lengths)

         with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
-            autoregressive_conds = torch.stack([ conds for conds in batch["conds_0"] ])
-            diffusion_conds = torch.stack([ conds for conds in batch["conds_1"] ])
-
-            autoregressive_latents = torch.stack([ latents for latents in batch["latents_0"] ])
-            diffusion_latents = torch.stack([ latents for latents in batch["latents_1"] ])
-
-            text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
-            text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
-            mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
-            wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
-
             # autoregressive pass
-            if True:
+            if generate_codes:
                 codes = autoregressive.inference_speech(
                     autoregressive_latents,
                     text_tokens,
                     do_sample=True,
-                    #top_p=top_p,
+                    top_p=0.8,
                     temperature=temperature,
                     num_return_sequences=1,
-                    #length_penalty=length_penalty,
-                    #repetition_penalty=repetition_penalty,
+                    length_penalty=1.0,
+                    repetition_penalty=2.0,
                     max_generate_length=max_mel_tokens,
                 )
                 padding_needed = max_mel_tokens - codes.shape[1]
@@ -168,6 +128,22 @@ def run_eval(engines, eval_name, dl):
             else:
                 codes = mel_codes

+            for i, code in enumerate( codes ):
+                stop_token_indices = (codes[i] == stop_mel_token).nonzero()
+
+                if len(stop_token_indices) == 0:
+                    continue
+
+                codes[i][stop_token_indices] = 83
+                stm = stop_token_indices.min().item()
+                codes[i][stm:] = 83
+                if stm - 3 < codes[i].shape[0]:
+                    codes[i][-3] = 45
+                    codes[i][-2] = 45
+                    codes[i][-1] = 248
+
+            wav_lengths = torch.tensor([codes.shape[-1] * autoregressive.mel_length_compression], device=text_tokens.device)
+
             latents = autoregressive.forward(
                 autoregressive_latents,
                 text_tokens,
@@ -199,18 +175,47 @@ def run_eval(engines, eval_name, dl):
                 output_shape,
                 noise=noise,
                 model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
-                progress=verbose
+                progress=True
             )
             mels = denormalize_tacotron_mel(mel)[:,:,:output_seq_len]

             # vocoder pass
             wavs = vocoder.inference(mels)

-            for i, wav in enumerate( wavs ):
-                torchaudio.save( f"./data/{cfg.start_time}[{i}].wav", wav.cpu(), 24_000 )
+            return wavs

-            # process( name, batch, resps_list )
+    def process( name, batch, hyps, refs ):
+        for speaker, path, ref_audio, hyp_audio in zip(batch["spkr_name"], batch["path"], refs, hyps):
+            filename = f'{speaker}_{path.parts[-1]}'
+
+            # to-do, refine the output dir to be sane-er
+            ref_path = (cfg.log_dir / str(engines.global_step) / "ref" / filename).with_suffix(".wav")
+            hyp_path = (cfg.log_dir / str(engines.global_step) / name / eval_name / filename).with_suffix(".wav")
+            prom_path = (cfg.log_dir / str(engines.global_step) / name / "prom" / filename).with_suffix(".wav")
+
+            hyp_path.parent.mkdir(parents=True, exist_ok=True)
+            ref_path.parent.mkdir(parents=True, exist_ok=True)
+            prom_path.parent.mkdir(parents=True, exist_ok=True)
+
+            torchaudio.save( hyp_path, hyp_audio.cpu(), 24_000 )
+            torchaudio.save( ref_path, ref_audio.cpu(), 24_000 )
+
+            # pseudo loss calculation since we don't get the logits during eval
+            min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
+            ref_audio = ref_audio[..., 0:min_length]
+            hyp_audio = hyp_audio[..., 0:min_length]
+            stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
+
+    processed = 0
+    while processed < cfg.evaluation.size:
+        batch = to_device(next(iter(dl)), cfg.device)
+        batch_size = len(batch["text"])
+        processed += batch_size
+
+        hyp = generate( batch, generate_codes=True )
+        ref = generate( batch, generate_codes=False )
+
+        process( name, batch, hyp, ref )

     stats = {k: sum(v) / len(v) for k, v in stats.items()}

     engines_stats = {
@@ -254,7 +259,7 @@ def train():
             do_gc()

     #qnt.unload_model()
-
+
     if args.eval:
         return eval_fn(engines=trainer.load_engines())
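run_eval now pushes every batch through the same generate() helper twice: once sampling codes from the AR model (hyp) and once reusing the ground-truth mel codes (ref), so both signals share the diffusion and vocoder stages and the mel STFT pseudo-loss isolates autoregressive quality rather than decoder error. The loop at the heart of it, as it reads above:

    hyp = generate( batch, generate_codes=True )   # AR-sampled codes
    ref = generate( batch, generate_codes=False )  # ground-truth codes, same decode path
    process( name, batch, hyp, ref )               # saves both wavs, logs mel_stft_loss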
diff --git a/tortoise_tts/version.py b/tortoise_tts/version.py
deleted file mode 100644
index 6e7f03e..0000000
--- a/tortoise_tts/version.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "0.0.1-dev20240617224834"