don't pad output mel tokens to speed up diffusion (despite copying it exactly from tortoise)

mrq 2024-06-19 15:27:11 -05:00
parent 849de13f27
commit 5d24631bfb
3 changed files with 83 additions and 26 deletions

View File

@@ -180,18 +180,22 @@ class TTS():
 				repetition_penalty=repetition_penalty,
 				max_generate_length=max_ar_steps,
 			)
 
+			"""
 			padding_needed = max_ar_steps - codes.shape[1]
 			codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
+			"""
 
 			for i, code in enumerate( codes ):
 				stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
-				stm = stop_token_indices.min().item()
+				if len(stop_token_indices) == 0:
+					continue
+
+				codes[i][stop_token_indices] = 83
+				stm = stop_token_indices.min().item()
 				codes[i][stm:] = 83
 				if stm - 3 < codes[i].shape[0]:
 					codes[i][-3] = 45
 					codes[i][-2] = 45
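For context, the change above stops padding every batch entry out to max_ar_steps with stop tokens (that block is now fenced off in a docstring) and instead trims each sequence at its first stop token, skipping sequences that never emitted one. A minimal standalone sketch of that cleanup, assuming codes is a [batch, length] LongTensor of mel codes; the constants 83 and 45 are the fixed trailing codes carried over from TorToiSe's fix_autoregressive_output:

import torch

def trim_at_stop_token(codes: torch.Tensor, stop_mel_token: int) -> torch.Tensor:
	# sketch only: mirrors the loop above, outside the TTS class
	for i in range(codes.shape[0]):
		stop_token_indices = (codes[i] == stop_mel_token).nonzero()
		if len(stop_token_indices) == 0:
			continue  # no stop token emitted; leave this sequence untouched
		codes[i][stop_token_indices] = 83      # overwrite every stop token
		stm = stop_token_indices.min().item()  # position of the first stop token
		codes[i][stm:] = 83                    # blank out everything after it
		if stm - 3 < codes[i].shape[0]:
			codes[i][-3] = 45                  # fixed trailing codes, per TorToiSe
			codes[i][-2] = 45
	return codes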

View File

@@ -52,6 +52,8 @@ try:
 except Exception as e:
 	print("Error while importing `xformers`", e)
 
+# from diffusers.models.attention_processing import AttnProcessor2_0
+# to-do: optimize this, as the diffuser *heavily* relies on this
 class QKVAttentionLegacy(nn.Module):
 	"""
 	A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@@ -61,7 +63,7 @@ class QKVAttentionLegacy(nn.Module):
 		super().__init__()
 		self.n_heads = n_heads
 
-	def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
+	def forward(self, qkv, mask=None, rel_pos=None):
 		"""
 		Apply QKV attention.
@@ -89,6 +91,46 @@ class QKVAttentionLegacy(nn.Module):
 		return a.reshape(bs, -1, length)
 
+class QKVAttention(nn.Module):
+	"""
+	A module which performs QKV attention and splits in a different order.
+	"""
+
+	def __init__(self, n_heads):
+		super().__init__()
+		self.n_heads = n_heads
+
+	def forward(self, qkv, mask=None, rel_pos=None):
+		"""
+		Apply QKV attention.
+
+		:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+		:return: an [N x (H * C) x T] tensor after attention.
+		"""
+		bs, width, length = qkv.shape
+		assert width % (3 * self.n_heads) == 0
+		ch = width // (3 * self.n_heads)
+		q, k, v = qkv.chunk(3, dim=1)
+		scale = 1 / math.sqrt(math.sqrt(ch))
+		weight = torch.einsum(
+			"bct,bcs->bts",
+			(q * scale).view(bs * self.n_heads, ch, length),
+			(k * scale).view(bs * self.n_heads, ch, length),
+		)  # More stable with f16 than dividing afterwards
+		if rel_pos is not None:
+			weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+		weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+		if mask is not None:
+			# The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+			mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+			weight = weight * mask
+		a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+		return a.reshape(bs, -1, length)
+
+	@staticmethod
+	def count_flops(model, _x, y):
+		return count_flops_attn(model, _x, y)
+
+# actually sourced from https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L278
 class AttentionBlock(nn.Module):
 	"""
 	An attention block that allows spatial positions to attend to each other.
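A note on the two modules: the added QKVAttention splits the fused qkv projection into q/k/v first and then into heads, while QKVAttentionLegacy splits heads first and then q/k/v. With identical projection weights, the two orders route different channels to different heads, which is why the use_new_attention_order flag below has to match how a checkpoint was trained. A shape-level illustration (all names here are illustrative, not from the repo):

import torch

bs, H, C, T = 2, 4, 8, 10            # batch, heads, channels per head, timesteps
qkv = torch.randn(bs, 3 * H * C, T)  # fused qkv projection output

# legacy order: split heads first, then q/k/v within each head
q1, k1, v1 = qkv.reshape(bs * H, 3 * C, T).split(C, dim=1)

# new order: split q/k/v first, then heads within each chunk
q2, k2, v2 = (t.reshape(bs * H, C, T) for t in qkv.chunk(3, dim=1))

print(q1.shape == q2.shape)  # True: both are (bs*H, C, T)
print(torch.equal(q1, q2))   # False: each head sees a different channel slice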
@@ -102,12 +144,12 @@ class AttentionBlock(nn.Module):
 		channels,
 		num_heads=1,
 		num_head_channels=-1,
-		do_checkpoint=True,
+		use_checkpoint=False,
+		use_new_attention_order=False,
 		relative_pos_embeddings=False,
 	):
 		super().__init__()
 		self.channels = channels
-		self.do_checkpoint = do_checkpoint
 		if num_head_channels == -1:
 			self.num_heads = num_heads
 		else:
@@ -115,10 +157,15 @@ class AttentionBlock(nn.Module):
 				channels % num_head_channels == 0
 			), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
 			self.num_heads = channels // num_head_channels
+		self.use_checkpoint = use_checkpoint
 		self.norm = normalization(channels)
 		self.qkv = nn.Conv1d(channels, channels * 3, 1)
-		# split heads before split qkv
-		self.attention = QKVAttentionLegacy(self.num_heads)
+		if use_new_attention_order:
+			# split qkv before split heads
+			self.attention = QKVAttention(self.num_heads)
+		else:
+			# split heads before split qkv
+			self.attention = QKVAttentionLegacy(self.num_heads)
 		self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
 
 		if relative_pos_embeddings:
@ -126,10 +173,16 @@ class AttentionBlock(nn.Module):
else:
self.relative_pos_embeddings = None
def forward(self, x, mask=None):
def forward(self, x):
if self.use_checkpoint:
return checkpoint(self._forward, (x,), self.parameters(), True)
return self._forward(x)
def _forward(self, x, mask=None):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
#h = self.attention(qkv)
h = self.attention(qkv, mask, self.relative_pos_embeddings)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
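The forward path now gates activation checkpointing on the module's own use_checkpoint flag instead of the old do_checkpoint constructor argument that callers had to honor themselves. The checkpoint(self._forward, (x,), self.parameters(), True) call appears to be the guided-diffusion-style helper taking (function, inputs, params, flag), not torch.utils.checkpoint directly. A minimal sketch of the same gating pattern with stock PyTorch, where CheckpointedBlock is a hypothetical stand-in:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
	def __init__(self, channels: int, use_checkpoint: bool = False):
		super().__init__()
		self.use_checkpoint = use_checkpoint
		self.body = nn.Conv1d(channels, channels, 1)

	def forward(self, x: torch.Tensor) -> torch.Tensor:
		if self.use_checkpoint:
			# trade compute for memory: activations inside _forward are
			# recomputed during backward instead of being stored
			return checkpoint(self._forward, x, use_reentrant=False)
		return self._forward(x)

	def _forward(self, x: torch.Tensor) -> torch.Tensor:
		return x + self.body(x)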
@@ -479,29 +532,29 @@ TACOTRON_MEL_MIN = -11.512925148010254
 def denormalize_tacotron_mel(norm_mel):
-    return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
+	return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
 
 def normalize_tacotron_mel(mel):
-    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
+	return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
 
 def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    """
-    PARAMS
-    ------
-    C: compression factor
-    """
-    return torch.log(torch.clamp(x, min=clip_val) * C)
+	"""
+	PARAMS
+	------
+	C: compression factor
+	"""
+	return torch.log(torch.clamp(x, min=clip_val) * C)
 
 def dynamic_range_decompression(x, C=1):
-    """
-    PARAMS
-    ------
-    C: compression factor used to compress
-    """
-    return torch.exp(x) / C
+	"""
+	PARAMS
+	------
+	C: compression factor used to compress
+	"""
+	return torch.exp(x) / C
 
 class STFT(torch.nn.Module):
 	"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""

View File

@@ -1438,11 +1438,11 @@ class DiffusionTTS(nn.Module):
 		)
 		self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
 					nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
-					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
+					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+					AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False))
 		self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
 		self.conditioning_timestep_integrator = TimestepEmbedSequential(
 			DiffusionLayer(model_channels, dropout, num_heads),