From 99be4874824d95cd6aba0357404c6f1a32c58965 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 14:49:24 -0500
Subject: [PATCH] backported old fork features (kv_cache (which looking back
 seems like a spook), ddim sampling, etc)

---
 README.md                            |   8 ++
 tortoise_tts/__main__.py             |   4 +
 tortoise_tts/inference.py            |   6 +-
 tortoise_tts/models/arch_utils.py    |  15 ++-
 tortoise_tts/models/diffusion.py     |  15 ++-
 tortoise_tts/models/unified_voice.py | 131 ++++++++++++++++-----------
 tortoise_tts/webui.py                |   3 +
 7 files changed, 125 insertions(+), 57 deletions(-)

diff --git a/README.md b/README.md
index 8a5828d..62b52fc 100644
--- a/README.md
+++ b/README.md
@@ -37,10 +37,15 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
 - [X] Reimplement original inferencing through TorToiSe (as done with `api.py`)
   - [ ] Reimplement candidate selection with the CLVP
+  - [ ] Reimplement redaction with the Wav2Vec2
 - [X] Implement training support (without DLAS)
   - [X] Feature parity with the VALL-E training setup with preparing a dataset ahead of time
 - [ ] Automagic offloading to CPU for unused models (for training and inferencing)
 - [X] Automagic handling of the original weights into compatible weights
+- [ ] Reimplement added features from my original fork:
+  - [ ] "Better" conditioning latents calculation
+  - [x] Use of KV-cache for the AR
+  - [x] Re-enable DDIM sampler
 - [ ] Extend the original inference routine with additional features:
   - [ ] non-float32 / mixed precision for the entire stack
   - [x] BitsAndBytes support
@@ -48,10 +53,13 @@ For training a LoRA, uncomment the `loras` block in your training YAML.
   - [x] LoRAs
 - [x] Web UI
   - [ ] Feature parity with [ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning)
+    - Although I feel a lot of its features are the wrong way to go about it.
 - [ ] Additional samplers for the autoregressive model
 - [ ] Additional samplers for the diffusion model
 - [ ] BigVGAN in place of the original vocoder
 - [ ] XFormers / flash_attention_2 for the autoregressive model
+  - Beyond HF's internal implementation of handling alternative attention
+  - Both the AR and diffusion models also do their own attention...
 - [ ] Some vector embedding store to find the "best" utterance to pick
 - [ ] Documentation
diff --git a/tortoise_tts/__main__.py b/tortoise_tts/__main__.py
index c48765e..a8d17c9 100755
--- a/tortoise_tts/__main__.py
+++ b/tortoise_tts/__main__.py
@@ -22,6 +22,8 @@ def main():
 	parser.add_argument("--length-penalty", type=float, default=0.0)
 	parser.add_argument("--beam-width", type=int, default=0)
 
+	parser.add_argument("--diffusion-sampler", type=str, default="ddim")
+
 	parser.add_argument("--yaml", type=Path, default=None)
 	parser.add_argument("--device", type=str, default=None)
 	parser.add_argument("--amp", action="store_true")
@@ -56,6 +58,8 @@ def main():
 		#repetition_penalty_decay=args.repetition_penalty_decay,
 		length_penalty=args.length_penalty,
 		beam_width=args.beam_width,
+
+		diffusion_sampler=args.diffusion_sampler
 	)
 	"""
 	language=args.language,
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index 11cef83..8b2842c 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -114,6 +114,9 @@ class TTS():
 		beam_width=1,
 		#mirostat_tau=0,
 		#mirostat_eta=0.1,
+
+		diffusion_sampler="ddim",
+
 		out_path=None
 	):
 		lines = text.split("\n")
@@ -222,9 +225,10 @@ class TTS():
 			precomputed_embeddings = diffusion.timestep_independent(latents, diffusion_latents, output_seq_len, False)
 			noise = torch.randn(output_shape, device=latents.device) * diffusion_temp
-			mel = diffuser.p_sample_loop(
+			mel = diffuser.sample_loop(
 				diffusion,
 				output_shape,
+				sampler=diffusion_sampler,
 				noise=noise,
 				model_kwargs={'precomputed_aligned_embeddings': precomputed_embeddings},
 				progress=True
diff --git a/tortoise_tts/models/arch_utils.py b/tortoise_tts/models/arch_utils.py
index 231cb51..4991cec 100644
--- a/tortoise_tts/models/arch_utils.py
+++ b/tortoise_tts/models/arch_utils.py
@@ -42,6 +42,16 @@ def normalization(channels):
 	return GroupNorm32(groups, channels)
 
+AVAILABLE_ATTENTIONS = ["mem_efficient", "math", "sdpa"]
+
+try:
+	from xformers.ops import LowerTriangularMask
+	from xformers.ops.fmha import memory_efficient_attention
+
+	AVAILABLE_ATTENTIONS.append("xformers")
+except Exception as e:
+	print("Error while importing `xformers`", e)
+
 class QKVAttentionLegacy(nn.Module):
 	"""
 	A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@@ -51,13 +61,14 @@ class QKVAttentionLegacy(nn.Module):
 		super().__init__()
 		self.n_heads = n_heads
 
-	def forward(self, qkv, mask=None, rel_pos=None):
+	def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
 		"""
 		Apply QKV attention.
 
 		:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
 		:return: an [N x (H * C) x T] tensor after attention.
 		"""
+
 		bs, width, length = qkv.shape
 		assert width % (3 * self.n_heads) == 0
 		ch = width // (3 * self.n_heads)
@@ -73,11 +84,11 @@ class QKVAttentionLegacy(nn.Module):
 			# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
 			mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
 			weight = weight * mask
+
 		a = torch.einsum("bts,bcs->bct", weight, v)
 
 		return a.reshape(bs, -1, length)
 
-
 class AttentionBlock(nn.Module):
 	"""
 	An attention block that allows spatial positions to attend to each other.
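Note: the arch_utils.py hunks above probe for xformers at import time and add a mode="xformers" parameter to QKVAttentionLegacy.forward(), but nothing in these hunks branches on `mode` yet. For illustration only, a sketch of the kind of dispatch the probe enables; the helper name and the sdpa/math split are mine, not part of the patch:

	import torch
	import torch.nn.functional as F

	def qkv_attention(q, k, v, mode="sdpa"):
		# q, k, v: [batch * heads, channels, tokens], as split inside forward().
		if mode == "sdpa" and hasattr(F, "scaled_dot_product_attention"):
			# PyTorch 2.x fused attention; it expects [..., tokens, channels].
			out = F.scaled_dot_product_attention(
				q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2)
			)
			return out.transpose(-1, -2)
		# "math" fallback: the same softmax(QK^T)V path forward() already uses.
		scale = q.shape[1] ** -0.25
		weight = torch.softmax(
			torch.einsum("bct,bcs->bts", q * scale, k * scale).float(), dim=-1
		).type(q.dtype)
		return torch.einsum("bts,bcs->bct", weight, v)

Both branches compute the same attention; sdpa just fuses it (its 1/sqrt(C) scaling equals the two C^-0.25 factors applied above).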
diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py
index 12d6257..2b1926a 100644
--- a/tortoise_tts/models/diffusion.py
+++ b/tortoise_tts/models/diffusion.py
@@ -9,6 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 
+from tqdm.auto import tqdm
+
 from torch import autocast
 
 from .arch_utils import normalization, AttentionBlock
@@ -493,6 +495,16 @@ class GaussianDiffusion:
 		)
 		return out
 
+	def sample_loop(self, *args, **kwargs):
+		# YUCK
+		sampler = kwargs.pop("sampler").lower() if "sampler" in kwargs else "ddim"
+		if sampler == 'p':
+			return self.p_sample_loop(*args, **kwargs)
+		if sampler == 'ddim':
+			return self.ddim_sample_loop(*args, **kwargs)
+
+		raise RuntimeError(f"Sampler not implemented: {sampler}")
+
 	def p_sample(
 		self,
 		model,
@@ -780,9 +792,6 @@ class GaussianDiffusion:
 		indices = list(range(self.num_timesteps))[::-1]
 
 		if progress:
-			# Lazy import so that we don't depend on tqdm.
-			from tqdm.auto import tqdm
-
 			indices = tqdm(indices, disable=not progress)
 
 		for i in indices:
diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py
index 7828ad4..72c55fb 100644
--- a/tortoise_tts/models/unified_voice.py
+++ b/tortoise_tts/models/unified_voice.py
@@ -11,6 +11,7 @@ from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
 
 from .arch_utils import AttentionBlock
 from transformers import LogitsWarper
+from transformers import GPT2Config, GPT2Model
 
 AVAILABLE_ATTENTIONS = ["mem_efficient", "math"]
 
@@ -81,15 +82,16 @@ class ResBlock(nn.Module):
 	def forward(self, x):
 		return F.relu(self.net(x) + x)
 
-
 class GPT2InferenceModel(GPT2PreTrainedModel):
-	def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear):
+	def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=True):
 		super().__init__(config)
 		self.transformer = gpt
 		self.text_pos_embedding = text_pos_emb
 		self.embeddings = embeddings
 		self.lm_head = nn.Sequential(norm, linear)
+		self.kv_cache = kv_cache
+
 		# Model parallel
 		self.model_parallel = False
 		self.device_map = None
@@ -123,8 +125,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 		self.cached_mel_emb = mel_emb
 
 	def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-		token_type_ids = kwargs.get("token_type_ids", None)
+
+		if not self.kv_cache:
+			past = None
+
 		# only last token for inputs_ids if past is defined in kwargs
 		if past:
 			input_ids = input_ids[:, -1].unsqueeze(-1)
@@ -278,38 +283,6 @@ class LearnedPositionEmbeddings(nn.Module):
 	def get_fixed_embedding(self, ind, dev):
 		return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
 
-
-def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, attention_implementation="eager"):
-	"""
-	GPT-2 implemented by the HuggingFace library.
-	"""
-	from transformers import GPT2Config, GPT2Model
-	gpt_config = GPT2Config(
-		vocab_size=256, # Unused.
-		n_positions=max_mel_seq_len+max_text_seq_len,
-		n_ctx=max_mel_seq_len+max_text_seq_len,
-		n_embd=model_dim,
-		n_layer=layers,
-		n_head=heads,
-		use_cache=not checkpointing,
-		attention_implementation=attention_implementation
-	)
-	gpt = GPT2Model(gpt_config)
-
-	if checkpointing:
-		gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-			use_reentrant=False
-		))
-
-	# Override the built in positional embeddings
-	del gpt.wpe
-	gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
-	# Built-in token embeddings are unused.
-	del gpt.wte
-	return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim),\
-		None, None
-
-
 class MelEncoder(nn.Module):
 	def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
 		super().__init__()
@@ -341,6 +314,7 @@ class UnifiedVoice(nn.Module):
 		model_dim=1024, # 512
 		heads=16, # 8
 		max_text_tokens=402, # 120
+		max_prompt_tokens=2, # XTTS2 uses 70
 		max_mel_tokens=604, # 250
 		max_conditioning_inputs=2, # 1
 		mel_length_compression=1024,
@@ -392,17 +366,48 @@ class UnifiedVoice(nn.Module):
 		self.heads = heads
 		self.max_mel_tokens = max_mel_tokens
 		self.max_text_tokens = max_text_tokens
+		self.max_prompt_tokens = max_prompt_tokens
 		self.model_dim = model_dim
 		self.max_conditioning_inputs = max_conditioning_inputs
 		self.mel_length_compression = mel_length_compression
 		self.conditioning_encoder = ConditioningEncoder(80, model_dim, num_attn_heads=heads)
 		self.text_embedding = nn.Embedding(self.number_text_tokens*types+1, model_dim)
+
 		if use_mel_codes_as_input:
 			self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
 		else:
 			self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
-		self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
-			build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens+2+self.max_conditioning_inputs, self.max_text_tokens+2, checkpointing, attention_implementation)
+
+		max_mel_seq_len = self.max_mel_tokens+2+self.max_conditioning_inputs
+		max_text_seq_len = self.max_text_tokens+2
+
+		gpt_config = GPT2Config(
+			vocab_size=256, # Unused.
+			n_positions=max_mel_seq_len+max_text_seq_len,
+			n_ctx=max_mel_seq_len+max_text_seq_len,
+			n_embd=model_dim,
+			n_layer=layers,
+			n_head=heads,
+			use_cache=not checkpointing,
+			attention_implementation=attention_implementation
+		)
+		self.gpt = GPT2Model(gpt_config)
+
+		if checkpointing:
+			self.gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+				use_reentrant=False
+			))
+
+		del self.gpt.wpe
+		self.gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Override the built-in positional embeddings
+		del self.gpt.wte
+		self.gpt.wte = None # Built-in token embeddings are unused.
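+		# NOTE: wte is only parked as None here; post_init_gpt2_config() points it
+		# back at self.mel_embedding before inference.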
+
+		self.mel_pos_embedding = LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
+		self.text_pos_embedding = LearnedPositionEmbeddings(max_text_seq_len, model_dim)
+		self.mel_layer_pos_embedding = None
+		self.text_layer_pos_embedding = None
+
 		if train_solo_embeddings:
 			self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
 			self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
@@ -421,6 +426,42 @@ class UnifiedVoice(nn.Module):
 		for module in embeddings:
 			module.weight.data.normal_(mean=0.0, std=.02)
 
+	def post_init_gpt2_config(self, kv_cache = True, use_deepspeed = False):
+		seq_length = self.max_mel_tokens + self.max_text_tokens + self.max_prompt_tokens
+		self.inference_model = GPT2InferenceModel(
+			GPT2Config(
+				vocab_size=self.max_mel_tokens,
+				n_positions=seq_length,
+				n_ctx=seq_length,
+				n_embd=self.model_dim,
+				n_layer=self.layers,
+				n_head=self.heads,
+				gradient_checkpointing=False,
+				use_cache=True,
+				attn_implementation=self.attention_implementation,
+			),
+			self.gpt,
+			self.mel_pos_embedding,
+			self.mel_embedding,
+			self.final_norm,
+			self.mel_head,
+			kv_cache=kv_cache
+		)
+
+		# technically should already be done on the framework side, but my old fork had this here anyways
+		if use_deepspeed:
+			import deepspeed
+			self.ds_engine = deepspeed.init_inference(
+				model=self.inference_model,
+				mp_size=1,
+				replace_with_kernel_inject=True,
+				# dtype=torch.float32
+			)
+			self.inference_model = self.ds_engine.module
+
+		self.inference_model.eval()
+
+		self.gpt.wte = self.mel_embedding
+
 	def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
 		inp = F.pad(input, (1,0), value=start_token)
 		tar = F.pad(input, (0,1), value=stop_token)
@@ -547,23 +588,11 @@ class UnifiedVoice(nn.Module):
 		return loss_text.mean(), loss_mel.mean(), mel_logits
 
 	def inference_speech(self, speech_conditioning_latent, text_inputs, input_tokens=None, num_return_sequences=1,
-						 max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
+						 max_generate_length=None, typical_sampling=False, typical_mass=.9, kv_cache=True, **hf_generate_kwargs):
 		seq_length = self.max_mel_tokens + self.max_text_tokens + 2
 		if not hasattr(self, 'inference_model'):
 			# TODO: Decouple gpt_config from this inference model.
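+			# lazily build and cache the HF inference wrapper on the first call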
-			gpt_config = GPT2Config(
-				vocab_size=self.max_mel_tokens,
-				n_positions=seq_length,
-				n_ctx=seq_length,
-				n_embd=self.model_dim,
-				n_layer=self.layers,
-				n_head=self.heads,
-				gradient_checkpointing=False,
-				use_cache=True,
-				attn_implementation=self.attention_implementation,
-			)
-			self.inference_model = GPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
-			self.gpt.wte = self.mel_embedding
+			self.post_init_gpt2_config(kv_cache = kv_cache)
 
 		text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
 		text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
diff --git a/tortoise_tts/webui.py b/tortoise_tts/webui.py
index 33984db..868712d 100644
--- a/tortoise_tts/webui.py
+++ b/tortoise_tts/webui.py
@@ -96,6 +96,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
 	parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
 	parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
+	parser.add_argument("--diffusion-sampler", type=str, default=kwargs["diffusion-sampler"])
 	"""
 	parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
 	parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
@@ -125,6 +126,8 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		#repetition_penalty_decay=args.repetition_penalty_decay,
 		length_penalty=args.length_penalty,
 		beam_width=args.beam_width,
+
+		diffusion_sampler=args.diffusion_sampler,
 	)
 
 	wav = wav.squeeze(0).cpu().numpy()
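A usage note: nothing in this patch calls post_init_gpt2_config() eagerly; inference_speech() builds the wrapper lazily on its first call. A minimal sketch of driving it directly, where the input shapes and the untrained weights are illustrative assumptions rather than anything the patch ships:

	import torch
	from tortoise_tts.models.unified_voice import UnifiedVoice

	model = UnifiedVoice()  # defaults per the patched signature (model_dim=1024, etc.)
	model.eval()

	# Build the cached GPT-2 inference wrapper up front. With kv_cache=True the
	# HF generate() loop only feeds the newest token each step; with kv_cache=False
	# prepare_inputs_for_generation() drops `past` and re-runs the whole prefix.
	model.post_init_gpt2_config(kv_cache = True, use_deepspeed = False)

	# Stand-in inputs: a conditioning latent of shape (b, model_dim) and token IDs.
	conditioning_latent = torch.randn(1, 1024)
	text_tokens = torch.randint(0, 255, (1, 32))

	with torch.no_grad():
		codes = model.inference_speech(conditioning_latent, text_tokens, kv_cache=True)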