From 5d24631bfb6f655d6b9dce17804ea4082e0c923f Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Jun 2024 15:27:11 -0500
Subject: [PATCH] don't pad output mel tokens to speed up diffusion (despite
 copying it exactly from tortoise)

---
 tortoise_tts/inference.py         |  6 +-
 tortoise_tts/models/arch_utils.py | 93 ++++++++++++++++++++++++-------
 tortoise_tts/models/diffusion.py  | 10 ++--
 3 files changed, 83 insertions(+), 26 deletions(-)

diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index 8b2842c..e6ca975 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -180,18 +180,22 @@ class TTS():
 				repetition_penalty=repetition_penalty,
 				max_generate_length=max_ar_steps,
 			)
+
+			"""
 			padding_needed = max_ar_steps - codes.shape[1]
 			codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
+			"""
 
 			for i, code in enumerate( codes ):
 				stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
 
 				if len(stop_token_indices) == 0:
 					continue
 
+				stm = stop_token_indices.min().item()
 				codes[i][stop_token_indices] = 83
-				stm = stop_token_indices.min().item()
 				codes[i][stm:] = 83
+
 				if stm - 3 < codes[i].shape[0]:
 					codes[i][-3] = 45
 					codes[i][-2] = 45
diff --git a/tortoise_tts/models/arch_utils.py b/tortoise_tts/models/arch_utils.py
index 4991cec..871f246 100644
--- a/tortoise_tts/models/arch_utils.py
+++ b/tortoise_tts/models/arch_utils.py
@@ -52,6 +52,8 @@ try:
 except Exception as e:
 	print("Error while importing `xformers`", e)
 
+# from diffusers.models.attention_processor import AttnProcessor2_0
+# to-do: optimize this, as the diffuser *heavily* relies on this
 class QKVAttentionLegacy(nn.Module):
 	"""
 	A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@@ -61,7 +63,7 @@ class QKVAttentionLegacy(nn.Module):
 		super().__init__()
 		self.n_heads = n_heads
 
-	def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
+	def forward(self, qkv, mask=None, rel_pos=None):
 		"""
 		Apply QKV attention.
 
@@ -89,6 +91,46 @@ class QKVAttentionLegacy(nn.Module):
 		return a.reshape(bs, -1, length)
 
 
+class QKVAttention(nn.Module):
+	"""
+	A module which performs QKV attention and splits in a different order.
+	"""
+
+	def __init__(self, n_heads):
+		super().__init__()
+		self.n_heads = n_heads
+
+	def forward(self, qkv, mask=None, rel_pos=None):
+		"""
+		Apply QKV attention.
+		:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+		:return: an [N x (H * C) x T] tensor after attention.
+		"""
+		bs, width, length = qkv.shape
+		assert width % (3 * self.n_heads) == 0
+		ch = width // (3 * self.n_heads)
+		q, k, v = qkv.chunk(3, dim=1)
+		scale = 1 / math.sqrt(math.sqrt(ch))
+		weight = torch.einsum(
+			"bct,bcs->bts",
+			(q * scale).view(bs * self.n_heads, ch, length),
+			(k * scale).view(bs * self.n_heads, ch, length),
+		)  # More stable with f16 than dividing afterwards
+		if rel_pos is not None:
+			weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+		weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+		if mask is not None:
+			# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+			mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+			weight = weight * mask
+		a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+		return a.reshape(bs, -1, length)
+
+	@staticmethod
+	def count_flops(model, _x, y):
+		return count_flops_attn(model, _x, y)
+
+# actually sourced from https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L278
 class AttentionBlock(nn.Module):
 	"""
 	An attention block that allows spatial positions to attend to each other.
@@ -102,12 +144,12 @@ class AttentionBlock(nn.Module):
 		channels,
 		num_heads=1,
 		num_head_channels=-1,
-		do_checkpoint=True,
+		use_checkpoint=False,
+		use_new_attention_order=False,
 		relative_pos_embeddings=False,
 	):
 		super().__init__()
 		self.channels = channels
-		self.do_checkpoint = do_checkpoint
 		if num_head_channels == -1:
 			self.num_heads = num_heads
 		else:
@@ -115,10 +157,15 @@ class AttentionBlock(nn.Module):
 				channels % num_head_channels == 0
 			), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 			self.num_heads = channels // num_head_channels
+		self.use_checkpoint = use_checkpoint
 		self.norm = normalization(channels)
 		self.qkv = nn.Conv1d(channels, channels * 3, 1)
-		# split heads before split qkv
-		self.attention = QKVAttentionLegacy(self.num_heads)
+		if use_new_attention_order:
+			# split qkv before split heads
+			self.attention = QKVAttention(self.num_heads)
+		else:
+			# split heads before split qkv
+			self.attention = QKVAttentionLegacy(self.num_heads)
 		self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
 		if relative_pos_embeddings:
@@ -126,10 +173,16 @@ class AttentionBlock(nn.Module):
 		else:
 			self.relative_pos_embeddings = None
 
-	def forward(self, x, mask=None):
+	def forward(self, x):
+		if self.use_checkpoint:
+			return checkpoint(self._forward, (x,), self.parameters(), True)
+		return self._forward(x)
+
+	def _forward(self, x, mask=None):
 		b, c, *spatial = x.shape
 		x = x.reshape(b, c, -1)
 		qkv = self.qkv(self.norm(x))
+		#h = self.attention(qkv)
 		h = self.attention(qkv, mask, self.relative_pos_embeddings)
 		h = self.proj_out(h)
 		return (x + h).reshape(b, c, *spatial)
@@ -479,29 +532,29 @@ TACOTRON_MEL_MIN = -11.512925148010254
 
 def denormalize_tacotron_mel(norm_mel):
-    return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
+	return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
 
 def normalize_tacotron_mel(mel):
-    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
+	return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
 
 def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    """
-    PARAMS
-    ------
-    C: compression factor
-    """
-    return torch.log(torch.clamp(x, min=clip_val) * C)
+	"""
+	PARAMS
+	------
+	C: compression factor
+	"""
+	return torch.log(torch.clamp(x, min=clip_val) * C)
 
 def dynamic_range_decompression(x, C=1):
-    """
-    PARAMS
-    ------
-    C: compression factor used to compress
-    """
-    return torch.exp(x) / C
+	"""
+	PARAMS
+	------
+	C: compression factor used to compress
+	"""
+	return torch.exp(x) / C
 
 class STFT(torch.nn.Module):
 	"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
diff --git a/tortoise_tts/models/diffusion.py b/tortoise_tts/models/diffusion.py
index 1d65270..00c7c1e 100644
--- a/tortoise_tts/models/diffusion.py
+++ b/tortoise_tts/models/diffusion.py
@@ -1438,11 +1438,11 @@ class DiffusionTTS(nn.Module):
 		)
 		self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
 											 nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
-											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
+											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+											 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False))
 		self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
 		self.conditioning_timestep_integrator = TimestepEmbedSequential(
 			DiffusionLayer(model_channels, dropout, num_heads),
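
For reference, the trimming logic that replaces the removed padding can be
exercised on its own. The sketch below is illustrative and not code from the
repo: stop_mel_token is a stand-in value (the real ID is read off the
autoregressive model), while 83 is the calm/padding mel token used in the
inference.py hunk above.

    import torch

    stop_mel_token = 999  # placeholder; the real ID comes from the AR model

    codes = torch.tensor([[12, 7, 83, 999, 4, 999, 5]])

    for i in range(codes.shape[0]):
        stop_token_indices = (codes[i] == stop_mel_token).nonzero()
        if len(stop_token_indices) == 0:
            continue  # no stop token emitted; leave the row untouched
        stm = stop_token_indices.min().item()
        codes[i][stm:] = 83  # overwrite from the first stop token onward

    print(codes)  # tensor([[12,  7, 83, 83, 83, 83, 83]])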
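A note on use_new_attention_order: the two attention modules are not
interchangeable on the same trained weights. QKVAttentionLegacy slices the
fused qkv projection as [heads, q|k|v, ch] (as in upstream tortoise), while
the newly added QKVAttention chunks it as [q|k|v, heads, ch], so the flag has
to match how a checkpoint's qkv conv was laid out. A shape-only sketch with
illustrative dimensions:

    import torch

    bs, n_heads, ch, length = 2, 4, 8, 10
    qkv = torch.randn(bs, 3 * n_heads * ch, length)

    # QKVAttentionLegacy: split heads first, then q/k/v
    q1, k1, v1 = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)

    # QKVAttention ("new order"): split q/k/v first, then heads
    q2, k2, v2 = [t.reshape(bs * n_heads, ch, length) for t in qkv.chunk(3, dim=1)]

    print(q1.shape == q2.shape)  # True: both torch.Size([8, 8, 10])
    print(torch.equal(q1, q2))   # False: same shape, different channel grouping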
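The checkpoint() call in AttentionBlock.forward refers to the
guided-diffusion-style helper with signature checkpoint(func, inputs, params,
flag); since every AttentionBlock touched by this patch is built with
use_checkpoint=False, that branch is never taken here. A rough equivalent
using the stock torch API, as a sketch rather than what the repo calls:

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class CheckpointedBlock(nn.Module):
        # illustrative stand-in for the forward/_forward split above
        def __init__(self, channels, use_checkpoint=False):
            super().__init__()
            self.use_checkpoint = use_checkpoint
            self.proj = nn.Conv1d(channels, channels, 1)

        def forward(self, x):
            if self.use_checkpoint:
                # stock torch checkpointing in place of the repo's helper
                return torch.utils.checkpoint.checkpoint(self._forward, x, use_reentrant=False)
            return self._forward(x)

        def _forward(self, x):
            return x + self.proj(x)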