don't pad output mel tokens to speed up diffusion (despite copying it exactly from tortoise)
parent 849de13f27
commit 5d24631bfb
@@ -180,18 +180,22 @@ class TTS():
                 repetition_penalty=repetition_penalty,
                 max_generate_length=max_ar_steps,
             )

+            """
             padding_needed = max_ar_steps - codes.shape[1]
             codes = F.pad(codes, (0, padding_needed), value=autoregressive.stop_mel_token)
+            """

             for i, code in enumerate( codes ):
                 stop_token_indices = (codes[i] == autoregressive.stop_mel_token).nonzero()
+                stm = stop_token_indices.min().item()

                 if len(stop_token_indices) == 0:
                     continue

                 codes[i][stop_token_indices] = 83
-                stm = stop_token_indices.min().item()
                 codes[i][stm:] = 83

                 if stm - 3 < codes[i].shape[0]:
                     codes[i][-3] = 45
                     codes[i][-2] = 45
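For reference, a self-contained sketch of what the post-processing above reduces to once the padding is dropped: the AR output keeps its generated length, and everything from the first stop token onward is overwritten in place. The token values 83 and 45 are taken verbatim from the hunk; `stop_mel_token` and the example shapes are placeholders.

    import torch

    stop_mel_token = 8192                      # placeholder: the real value comes from autoregressive.stop_mel_token
    codes = torch.randint(0, 8192, (2, 120))   # hypothetical AR output, shape [batch, length]
    codes[0, 90:] = stop_mel_token             # pretend sample 0 stopped after 90 tokens

    # note: no F.pad(codes, (0, max_ar_steps - codes.shape[1]), value=stop_mel_token) here
    for i, code in enumerate(codes):
        stop_token_indices = (codes[i] == stop_mel_token).nonzero()
        if len(stop_token_indices) == 0:
            continue                           # no stop token emitted, leave the row untouched
        codes[i][stop_token_indices] = 83      # replace stop tokens with token 83 (tortoise's "calm" token)
        stm = stop_token_indices.min().item()  # position of the first stop token
        codes[i][stm:] = 83                    # silence everything after it
        if stm - 3 < codes[i].shape[0]:
            codes[i][-3] = 45                  # trailing values copied from the hunk above
            codes[i][-2] = 45

Skipping the F.pad call means the diffusion stage only sees as many mel tokens as the autoregressive model actually produced, which is the speed-up the commit message refers to.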
@@ -52,6 +52,8 @@ try:
 except Exception as e:
     print("Error while importing `xformers`", e)

+# from diffusers.models.attention_processing import AttnProcessor2_0
+# to-do: optimize this, as the diffuser *heavily* relies on this
 class QKVAttentionLegacy(nn.Module):
     """
     A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
@@ -61,7 +63,7 @@ class QKVAttentionLegacy(nn.Module):
         super().__init__()
         self.n_heads = n_heads

-    def forward(self, qkv, mask=None, rel_pos=None, mode="xformers"):
+    def forward(self, qkv, mask=None, rel_pos=None):
         """
         Apply QKV attention.

@@ -89,6 +91,46 @@ class QKVAttentionLegacy(nn.Module):

         return a.reshape(bs, -1, length)

+class QKVAttention(nn.Module):
+    """
+    A module which performs QKV attention and splits in a different order.
+    """
+
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv, mask=None, rel_pos=None):
+        """
+        Apply QKV attention.
+        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+        :return: an [N x (H * C) x T] tensor after attention.
+        """
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.chunk(3, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = torch.einsum(
+            "bct,bcs->bts",
+            (q * scale).view(bs * self.n_heads, ch, length),
+            (k * scale).view(bs * self.n_heads, ch, length),
+        ) # More stable with f16 than dividing afterwards
+        if rel_pos is not None:
+            weight = rel_pos(weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        if mask is not None:
+            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+            weight = weight * mask
+        a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+        return a.reshape(bs, -1, length)
+
+    @staticmethod
+    def count_flops(model, _x, y):
+        return count_flops_attn(model, _x, y)
+
+# actually sourced from https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py#L278
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
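A quick shape check for the newly added class, assuming `QKVAttention` as defined in the hunk above is in scope; the head count and tensor sizes below are arbitrary.

    import torch

    n_heads, ch, length, bs = 4, 16, 50, 2
    attn = QKVAttention(n_heads)                      # class added in the hunk above
    qkv = torch.randn(bs, 3 * n_heads * ch, length)   # [N x (3 * H * C) x T], per the docstring
    out = attn(qkv)                                   # mask and rel_pos default to None
    assert out.shape == (bs, n_heads * ch, length)    # [N x (H * C) x T]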
@@ -102,12 +144,12 @@ class AttentionBlock(nn.Module):
         channels,
         num_heads=1,
         num_head_channels=-1,
-        do_checkpoint=True,
+        use_checkpoint=False,
+        use_new_attention_order=False,
         relative_pos_embeddings=False,
     ):
         super().__init__()
         self.channels = channels
-        self.do_checkpoint = do_checkpoint
         if num_head_channels == -1:
             self.num_heads = num_heads
         else:
@@ -115,8 +157,13 @@
                 channels % num_head_channels == 0
             ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels
+        self.use_checkpoint = use_checkpoint
         self.norm = normalization(channels)
         self.qkv = nn.Conv1d(channels, channels * 3, 1)
+        if use_new_attention_order:
+            # split qkv before split heads
+            self.attention = QKVAttention(self.num_heads)
+        else:
             # split heads before split qkv
             self.attention = QKVAttentionLegacy(self.num_heads)

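The two comments mark the only functional difference between the classes: how the fused output of the qkv conv is unpacked. A minimal sketch of the two unpacking orders (shapes arbitrary); since they assume different channel layouts from that conv, weights trained with one ordering generally need the matching attention class, which is what `use_new_attention_order` selects.

    import torch

    bs, n_heads, ch, length = 2, 4, 8, 10
    qkv = torch.randn(bs, 3 * n_heads * ch, length)   # output of the 1x1 qkv conv

    # legacy order (QKVAttentionLegacy): fold heads into the batch first, then split q/k/v per head
    q1, k1, v1 = qkv.reshape(bs * n_heads, 3 * ch, length).split(ch, dim=1)

    # new order (QKVAttention): split q/k/v across the full width first, fold heads in afterwards
    q2, k2, v2 = (t.reshape(bs * n_heads, ch, length) for t in qkv.chunk(3, dim=1))

    # same shapes, but the channel groupings differ, so the two classes expect
    # differently laid-out qkv conv weights
    assert q1.shape == q2.shape == (bs * n_heads, ch, length)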
@@ -126,10 +173,16 @@ class AttentionBlock(nn.Module):
         else:
             self.relative_pos_embeddings = None

-    def forward(self, x, mask=None):
+    def forward(self, x):
+        if self.use_checkpoint:
+            return checkpoint(self._forward, (x,), self.parameters(), True)
+        return self._forward(x)
+
+    def _forward(self, x, mask=None):
         b, c, *spatial = x.shape
         x = x.reshape(b, c, -1)
         qkv = self.qkv(self.norm(x))
+        #h = self.attention(qkv)
         h = self.attention(qkv, mask, self.relative_pos_embeddings)
         h = self.proj_out(h)
         return (x + h).reshape(b, c, *spatial)

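The forward/_forward split above is the usual gradient-checkpointing dispatch. A generic sketch of the pattern, with `torch.utils.checkpoint` standing in for the repo's own `checkpoint()` helper and an arbitrary block body:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint as torch_checkpoint   # stand-in for the local checkpoint()

    class CheckpointedBlock(nn.Module):
        def __init__(self, dim, use_checkpoint=False):
            super().__init__()
            self.use_checkpoint = use_checkpoint
            self.body = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

        def forward(self, x):
            if self.use_checkpoint:
                # drop intermediate activations and recompute _forward during backward
                return torch_checkpoint(self._forward, x, use_reentrant=False)
            return self._forward(x)

        def _forward(self, x):
            return x + self.body(x)

    x = torch.randn(2, 32, requires_grad=True)
    y = CheckpointedBlock(32, use_checkpoint=True)(x)
    y.sum().backward()                                # _forward activations are recomputed here

With use_checkpoint=True the block's intermediate activations are freed after the forward pass and recomputed during backward, trading compute for memory.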
@@ -1438,11 +1438,11 @@ class DiffusionTTS(nn.Module):
         )
         self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
                                                  nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
-                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False))
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False),
+                                                 AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, use_checkpoint=False))
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,1))
         self.conditioning_timestep_integrator = TimestepEmbedSequential(
             DiffusionLayer(model_channels, dropout, num_heads),