don't pad output mel tokens to speed up diffusion (despite copying it exactly from tortoise)
parent 849de13f27
commit 5d24631bfb
@@ -180,18 +180,22 @@ class TTS():
            repetition_penalty=repetition_penalty,
            max_generate_length=max_ar_steps,
        )

        """
        padding_needed = max_ar_steps - codes.shape[1]
        codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
        """

        for i, code in enumerate( codes ):
            stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
            stm = stop_token_indices.min().item()

            if len(stop_token_indices) == 0:
                continue

            codes[i][stop_token_indices] = 83
            stm = stop_token_indices.min().item()
            codes[i][stm:] = 83

            if stm - 3 < codes[i].shape[0]:
                codes[i][-3] = 45
                codes[i][-2] = 45
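For context: per the commit message, the F.pad block (copied from tortoise) is fenced off in a bare string literal so the AR codes are no longer padded out to max_ar_steps, and only the per-sequence cleanup remains. That cleanup can be read as a standalone helper; a minimal sketch, assuming codes is the [batch, length] tensor of mel codes returned by the autoregressive model (the helper name is ours, and the 83/45 values follow tortoise's fix_autoregressive_output, which derives them from how the DVAE encodes trailing silence):

import torch

def trim_after_stop_token(codes: torch.Tensor, stop_mel_token: int) -> torch.Tensor:
    # Sketch of the per-sequence cleanup kept by the hunk above; edits the
    # [batch, length] tensor of AR mel codes in place.
    for i in range(codes.shape[0]):
        stop_token_indices = (codes[i] == stop_mel_token).nonzero()
        if len(stop_token_indices) == 0:
            continue  # no stop token emitted; leave this row alone
        # Overwrite the stop token(s) and everything after the first one with
        # the code tortoise uses for trailing silence, then nudge the tail,
        # as fix_autoregressive_output does.
        codes[i][stop_token_indices] = 83
        stm = stop_token_indices.min().item()
        codes[i][stm:] = 83
        if stm - 3 < codes[i].shape[0]:
            codes[i][-3] = 45
            codes[i][-2] = 45
    return codes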
@@ -52,6 +52,8 @@ try:
except Exception as e:
    print("Error while importing `xformers`", e)

# from diffusers.models.attention_processing import AttnProcessor2_0
# to-do: optimize this, as the diffuser *heavily* relies on this
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
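The to-do above points at diffusers' AttnProcessor2_0, i.e. PyTorch 2.0's fused scaled_dot_product_attention. Purely as a sketch of what that optimization could look like, assuming the legacy head-first split used by tortoise's QKVAttentionLegacy and ignoring mask/rel_pos (nothing below exists in the repo):

import torch
import torch.nn.functional as F

def sdpa_qkv_legacy(qkv: torch.Tensor, n_heads: int) -> torch.Tensor:
    # qkv: [N, 3*H*C, T], split the same way the legacy path does.
    bs, width, length = qkv.shape
    ch = width // (3 * n_heads)
    q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    # SDPA expects [..., seq, head_dim], so move the time axis forward.
    # Its default scale of 1/sqrt(ch) matches the legacy 1/ch**0.25 applied to both q and k.
    a = F.scaled_dot_product_attention(
        q.transpose(-2, -1), k.transpose(-2, -1), v.transpose(-2, -1)
    )
    return a.transpose(-2, -1).reshape(bs, -1, length)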
@@ -61,7 +63,7 @@ class QKVAttentionLegacy(nn.Module):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

@@ -89,6 +91,46 @@ class QKVAttentionLegacy(nn.Module):

        return a.reshape(bs, -1, length)

class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)

# actually sourced from https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L278
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
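The docstring pins down the tensor layout of the QKVAttention path this hunk brings in. A quick shape check, as a sketch with arbitrary sizes (not values from the model config):

import torch

attn = QKVAttention(n_heads=4)
qkv = torch.randn(2, 3 * 4 * 16, 100)   # [N, 3*H*C, T] with H=4, C=16, T=100
out = attn(qkv)                          # no mask, no relative position bias
assert out.shape == (2, 4 * 16, 100)     # [N, H*C, T]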
@@ -102,12 +144,12 @@ class AttentionBlock(nn.Module):
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        use_checkpoint=False,
        use_new_attention_order=False,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
@@ -115,10 +157,15 @@ class AttentionBlock(nn.Module):
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
        if relative_pos_embeddings:
@@ -126,10 +173,16 @@ class AttentionBlock(nn.Module):
        else:
            self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
    def forward(self, x):
        if self.use_checkpoint:
            return checkpoint(self._forward, (x,), self.parameters(), True)
        return self._forward(x)

    def _forward(self, x, mask=None):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        #h = self.attention(qkv)
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
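Taken together, these AttentionBlock hunks move the block toward the openaimodel interface: do_checkpoint is swapped for use_checkpoint (gradient checkpointing handled in forward), and use_new_attention_order selects QKVAttention over QKVAttentionLegacy. A construction sketch with made-up sizes (the model itself passes relative_pos_embeddings=True, use_checkpoint=False, as the final hunk below shows):

import torch

blk = AttentionBlock(
    channels=256,
    num_heads=4,
    use_checkpoint=False,           # replaces the old do_checkpoint flag
    use_new_attention_order=False,  # False -> QKVAttentionLegacy, True -> QKVAttention
    relative_pos_embeddings=False,
)
x = torch.randn(2, 256, 100)        # [N, C, T]
y = blk(x)                          # same shape as x; mask is now only a _forward argument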
@@ -479,29 +532,29 @@ TACOTRON_MEL_MIN = -11.512925148010254


def denormalize_tacotron_mel(norm_mel):
    return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
    return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
    return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
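The old and new lines in this hunk are textually identical, so the change appears to be whitespace-only. For reference, normalize_tacotron_mel and denormalize_tacotron_mel are exact inverses over [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX]; a quick round-trip check (sketch, arbitrary shape):

import torch

mel = torch.rand(1, 100, 200) * (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN) + TACOTRON_MEL_MIN
normed = normalize_tacotron_mel(mel)    # maps [MEL_MIN, MEL_MAX] -> [-1, 1]
assert normed.min() >= -1.0 and normed.max() <= 1.0
assert torch.allclose(denormalize_tacotron_mel(normed), mel, atol=1e-4)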
@@ -1438,11 +1438,11 @@ class DiffusionTTS(nn.Module):
        )
        self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
            nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
            AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False))
        self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
        self.conditioning_timestep_integrator = TimestepEmbedSequential(
            DiffusionLayer(model_channels, dropout, num_heads),