Improvements to GptTts

James Betker 2021-07-31 15:57:57 -06:00
parent 31ee9ae262
commit 0c9e75bc69
5 changed files with 100 additions and 37 deletions

View File

@ -1,27 +1,19 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import munchify
from tqdm import tqdm
from models.arch_util import ConvGnSilu
from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D
from models.tacotron2 import hparams
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.tacotron2 import Postnet
from models.tacotron2.text import symbols
from models.gpt_voice.min_gpt import GPT, GPTConfig
from trainer.networks import register_model
# A Conv1d that masks out kernel elements ahead of the current location.
class CausalConv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.kernel_mask = torch.ones_like(self.weight)
self.kernel_mask[:, :, -(self.kernel_size[0]//2):] = 0
def forward(self, input):
self.kernel_mask = self.kernel_mask.to(input.device)
return self._conv_forward(input, self.weight * self.kernel_mask, self.bias)
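A minimal standalone check of what the kernel mask above buys (not part of the diff; layer sizes are illustrative): with 'same' padding, perturbing future frames must not change earlier outputs.
import torch
conv = CausalConv1d(4, 4, kernel_size=5, padding=2)   # uses the class defined above
x = torch.randn(1, 4, 16)
y = conv(x)
x_future = x.clone()
x_future[:, :, 10:] += 1.0                             # perturb only frames 10 and onward
y_future = conv(x_future)
# Frames before index 10 only ever see current/past inputs, so they are untouched.
assert torch.allclose(y[:, :, :10], y_future[:, :, :10])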
class GptTts(nn.Module):
def __init__(self):
super().__init__()
@ -36,24 +28,26 @@ class GptTts(nn.Module):
self.text_embedding = nn.Embedding(number_symbols, model_dim)
# Whenever we process MEL frames, we need to be careful to use causally masked convolutions to avoid adding bias
# into the model which we cannot provide at inference.
self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
ConvGnSilu(model_dim//2, model_dim, kernel_size=5, stride=2, convnd=CausalConv1d))
self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
PixelUnshuffle1D(2),
ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d))
# *_tags are additively applied to the corresponding embeddings so the model can tell text positions from audio positions.
self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim))
self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
nn.Upsample(scale_factor=2, mode='nearest'),
ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
# No need for causal convolutions when kernel_size=1
nn.Conv1d(model_dim//2, 1, kernel_size=1))
self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
nn.Upsample(scale_factor=2, mode='nearest'),
ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=CausalConv1d),
ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d))
self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
PixelShuffle1D(2),
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
PixelShuffle1D(2),
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
#self.postnet = Postnet(munchify(hparams.create_hparams()))
def forward(self, text_inputs, mel_targets, output_lengths):
# Pad mel_targets to be a multiple of 2
@ -62,13 +56,14 @@ class GptTts(nn.Module):
mel_targets = F.pad(mel_targets, (0,1))
text_emb = self.text_embedding(text_inputs)
text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
text_emb = text_emb + self.text_tags
mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
mel_emb = mel_emb + self.audio_tags
emb = torch.cat([text_emb,
self.separator.repeat(text_emb.shape[0],1,1),
mel_emb], dim=1)
enc = self.gpt(emb)
enc = self.gpt(emb, text_emb.shape[1])
mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
gates = self.gate_head(mel_portion).squeeze(1)
mel_pred = self.mel_head(mel_portion)
@ -82,6 +77,9 @@ class GptTts(nn.Module):
if padded:
mel_pred = mel_pred[:, :, :-1]
gates = gates[:, :-1]
#postnet_mel_pred = self.postnet(mel_pred)
#return mel_pred, postnet_mel_pred, gates
return mel_pred, gates
def test_guide(self, mel_guide, amount=50):
@ -95,15 +93,16 @@ class GptTts(nn.Module):
GATE_THRESHOLD = .95
text_emb = self.text_embedding(text_inputs)
text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
text_emb = text_emb + self.text_tags
b,s,c = text_emb.shape
emb = torch.cat([text_emb,
self.separator.repeat(text_emb.shape[0],1,1)], dim=1)
self.separator.repeat(text_emb.shape[0],1,1),], dim=1)
#self.test_guide(mel_guide)], dim=1)
completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
output = None
for i in tqdm(range(self.max_mel_frames)):
enc = self.gpt(emb)
enc = self.gpt(emb, text_emb.shape[1])
inferred = enc[:,s:,:].permute(0,2,1)
# Create output frames.
inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
@ -143,9 +142,10 @@ if __name__ == '__main__':
torch.randn(2,80,747),
torch.tensor([600,747]))
print(m.shape)
#print(p.shape)
print(g.shape)
o = gpt.infer(torch.randint(high=24, size=(2,60)))
print(o.shape)
#o = gpt.infer(torch.randint(high=24, size=(2,60)))
#print(o.shape)

View File

@ -56,7 +56,7 @@ class CausalSelfAttention(nn.Module):
.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
def forward(self, x, layer_past=None):
def forward(self, x, text_block_size):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
@ -66,10 +66,12 @@ class CausalSelfAttention(nn.Module):
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = att.masked_fill(self.mask[:,:,:T,:T].logical_or(
F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0,
float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
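The effect of the new mask, as a tiny standalone sketch (T and text_block_size are made up): the usual causal triangle is OR-ed with a full square over the text region, so text tokens attend to each other bidirectionally while MEL tokens stay strictly autoregressive.
import torch
import torch.nn.functional as F
T, text_block_size = 6, 3
causal = torch.tril(torch.ones(T, T))
text_block = F.pad(torch.ones(text_block_size, text_block_size),
                   (0, T - text_block_size, 0, T - text_block_size))
allowed = causal.logical_or(text_block)
# Rows 0-2 (text) may attend to every text position; rows 3-5 (MEL) remain causal.
# Wherever `allowed` is False, the attention logits are filled with -inf before softmax.
print(allowed.int())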
@ -91,8 +93,8 @@ class Block(nn.Module):
nn.Dropout(config.resid_pdrop),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
def forward(self, x, text_block_size):
x = x + self.attn(self.ln1(x), text_block_size)
x = x + self.mlp(self.ln2(x))
return x
@ -171,13 +173,14 @@ class GPT(nn.Module):
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
return optimizer
def forward(self, embeddings):
def forward(self, embeddings, text_block_sizes):
b, t, c = embeddings.size()
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
# forward the GPT model
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(embeddings + position_embeddings)
x = self.blocks(x)
for block in self.blocks:
x = block(x, text_block_sizes)
return x

View File

@ -0,0 +1,49 @@
import torch
# "long" and "short" denote longer and shorter samples
class PixelShuffle1D(torch.nn.Module):
"""
1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf
Upscales sample length, downscales channel length
"short" is input, "long" is output
"""
def __init__(self, upscale_factor):
super(PixelShuffle1D, self).__init__()
self.upscale_factor = upscale_factor
def forward(self, x):
batch_size = x.shape[0]
short_channel_len = x.shape[1]
short_width = x.shape[2]
long_channel_len = short_channel_len // self.upscale_factor
long_width = self.upscale_factor * short_width
x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width])
x = x.permute(0, 2, 3, 1).contiguous()
x = x.view(batch_size, long_channel_len, long_width)
return x
class PixelUnshuffle1D(torch.nn.Module):
"""
Inverse of 1D pixel shuffler
Upscales channel length, downscales sample length
"long" is input, "short" is output
"""
def __init__(self, downscale_factor):
super(PixelUnshuffle1D, self).__init__()
self.downscale_factor = downscale_factor
def forward(self, x):
batch_size = x.shape[0]
long_channel_len = x.shape[1]
long_width = x.shape[2]
short_channel_len = long_channel_len * self.downscale_factor
short_width = long_width // self.downscale_factor
x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor])
x = x.permute(0, 3, 1, 2).contiguous()
x = x.view([batch_size, short_channel_len, short_width])
return x
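A quick shape check (illustrative, not part of the file): PixelUnshuffle1D trades sequence length for channels and PixelShuffle1D reverses it exactly, which is how GptTts halves the MEL sequence before the GPT and restores it in the output heads.
import torch
x = torch.randn(2, 8, 10)                   # (batch, channels, length)
unshuffled = PixelUnshuffle1D(2)(x)         # -> (2, 16, 5): half the length, twice the channels
restored = PixelShuffle1D(2)(unshuffled)    # -> (2, 8, 10): exact inverse
assert unshuffled.shape == (2, 16, 5)
assert torch.equal(restored, x)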

View File

@ -3,6 +3,8 @@ import logging
import random
import argparse
import torchvision
import utils
import utils.options as option
import utils.util as util
@ -19,11 +21,20 @@ def forward_pass(model, denoiser, data, output_dir, opt, b):
with torch.no_grad():
model.feed_data(data, 0)
model.test()
pred_waveforms = model.eval_state[opt['eval']['output_state']][0]
pred_waveforms = denoiser(pred_waveforms)
ground_truth_waveforms = model.eval_state[opt['eval']['ground_truth']][0]
ground_truth_waveforms = denoiser(ground_truth_waveforms)
for i in range(pred_waveforms.shape[0]):
# Output predicted mels and waveforms.
pred_mel = model.eval_state[opt['eval']['pred_mel']][i]
pred_mel = ((pred_mel - pred_mel.mean()) / max(abs(pred_mel.min()), pred_mel.max())).unsqueeze(1)
torchvision.utils.save_image(pred_mel, osp.join(output_dir, f'{b}_{i}_pred_mel.png'))
gt_mel = model.eval_state[opt['eval']['ground_truth_mel']][i]
gt_mel = ((gt_mel - gt_mel.mean()) / max(abs(gt_mel.min()), gt_mel.max())).unsqueeze(1)
torchvision.utils.save_image(gt_mel, osp.join(output_dir, f'{b}_{i}_gt_mel.png'))
audio = pred_waveforms[i][0].cpu().numpy()
wavfile.write(osp.join(output_dir, f'{b}_{i}.wav'), 22050, audio)
audio = ground_truth_waveforms[i][0].cpu().numpy()

View File

@ -54,7 +54,7 @@ class SGDNoBiasMomentum(Optimizer):
if weight_decay != 0:
d_p = d_p.add(p, alpha=weight_decay)
# **this is the only modification over standard torch.optim.SGD:
is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
is_bn_or_bias = (hasattr(p, 'is_norm') and p.is_norm) or (hasattr(p, 'is_bias') and p.is_bias)
if not is_bn_or_bias and momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
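For context, this optimizer only skips momentum for parameters that were tagged ahead of time. A hedged sketch of one way such tagging could look (the attribute names follow the check above; how the repository actually sets them is outside this diff):
import torch.nn as nn
def tag_norm_and_bias(model: nn.Module):
    # Illustrative tagging pass: mark normalization weights and all biases so
    # SGDNoBiasMomentum can exclude them from momentum.
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)
    for m in model.modules():
        if isinstance(m, norm_types):
            for p in m.parameters(recurse=False):
                p.is_norm = True
        bias = getattr(m, 'bias', None)
        if isinstance(bias, nn.Parameter):
            bias.is_bias = True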