Improvements to GptTts

James Betker 2021-07-31 15:57:57 -06:00
parent 31ee9ae262
commit 0c9e75bc69
5 changed files with 100 additions and 37 deletions

View File

@@ -1,27 +1,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from munch import munchify
 from tqdm import tqdm
 
 from models.arch_util import ConvGnSilu
+from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D
+from models.tacotron2 import hparams
 from models.tacotron2.taco_utils import get_mask_from_lengths
+from models.tacotron2.tacotron2 import Postnet
 from models.tacotron2.text import symbols
 from models.gpt_voice.min_gpt import GPT, GPTConfig
 from trainer.networks import register_model
 
 
-# A Conv1d that masks out kernel elements ahead of the current location.
-class CausalConv1d(nn.Conv1d):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.kernel_mask = torch.ones_like(self.weight)
-        self.kernel_mask[:, :, -(self.kernel_size[0]//2):] = 0
-
-    def forward(self, input):
-        self.kernel_mask = self.kernel_mask.to(input.device)
-        return self._conv_forward(input, self.weight * self.kernel_mask, self.bias)
-
-
 class GptTts(nn.Module):
     def __init__(self):
         super().__init__()
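A note on the deletion above: the property CausalConv1d provided can be sanity-checked by perturbing only future timesteps and confirming that earlier outputs do not move. A minimal sketch of that check — the class mirrors the deleted code, but the test harness is illustrative and not from this repo:

import torch
import torch.nn as nn

# Mirror of the deleted CausalConv1d, for illustration only.
class CausalConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.kernel_mask = torch.ones_like(self.weight)
        # Zero the trailing half of the kernel so position t cannot see t+1..t+k//2.
        self.kernel_mask[:, :, -(self.kernel_size[0]//2):] = 0

    def forward(self, input):
        self.kernel_mask = self.kernel_mask.to(input.device)
        return self._conv_forward(input, self.weight * self.kernel_mask, self.bias)

conv = CausalConv1d(4, 4, kernel_size=5, padding=2)
x = torch.randn(1, 4, 16)
y1 = conv(x)
x2 = x.clone()
x2[:, :, 10:] = 99.0   # perturb only future positions
y2 = conv(x2)
# Outputs at positions 0..9 should be unaffected by the perturbation.
assert torch.allclose(y1[:, :, :10], y2[:, :, :10])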
@@ -36,24 +28,26 @@ class GptTts(nn.Module):
         self.text_embedding = nn.Embedding(number_symbols, model_dim)
         # Whenever we process MEL frames, we need to be careful to use causally masked convolutions to avoid adding bias
         # into the model which we cannot provide in inference.
-        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
-                                         ConvGnSilu(model_dim//2, model_dim, kernel_size=5, stride=2, convnd=CausalConv1d))
+        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
+                                         PixelUnshuffle1D(2),
+                                         ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
+                                         ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d))
         # *_tags are additively applied to
         self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
         self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
         self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
+        self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim))
         self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
-        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
-                                       nn.Upsample(scale_factor=2, mode='nearest'),
-                                       ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
-                                       # No need for causal convolutions when kernel_size=1
-                                       nn.Conv1d(model_dim//2, 1, kernel_size=1))
-        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
-                                      nn.Upsample(scale_factor=2, mode='nearest'),
-                                      ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
-                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=CausalConv1d),
-                                      ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d))
+        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
+                                       PixelShuffle1D(2),
+                                       ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
+                                       ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
+        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
+                                      PixelShuffle1D(2),
+                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
+                                      ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
+        #self.postnet = Postnet(munchify(hparams.create_hparams()))
 
     def forward(self, text_inputs, mel_targets, output_lengths):
         # Pad mel_targets to be a multiple of 2
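The replacement encoder trades the strided causal conv for pointwise (kernel_size=1) convs around a PixelUnshuffle1D, which halves the sequence by folding adjacent frame pairs into channels; pointwise convs cannot leak future frames, so the causal masking becomes unnecessary. A rough shape walk-through, assuming model_dim=512 and an 80-bin mel input (hypothetical values; the real ones are set earlier in __init__):

import torch
from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D  # module added by this commit

# Hypothetical dims for illustration; the real values live in GptTts.__init__.
B, model_dim, T = 2, 512, 100

# After the first pointwise ConvGnSilu, a mel input (B, 80, T) becomes (B, model_dim//2, T).
h = torch.randn(B, model_dim // 2, T)
# PixelUnshuffle1D(2) folds adjacent frame pairs into channels: length halves, channels double.
h = PixelUnshuffle1D(2)(h)
assert h.shape == (B, model_dim, T // 2)  # the two remaining pointwise convs preserve this shape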
@@ -62,13 +56,14 @@ class GptTts(nn.Module):
             mel_targets = F.pad(mel_targets, (0,1))
 
         text_emb = self.text_embedding(text_inputs)
+        text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
         text_emb = text_emb + self.text_tags
         mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
         mel_emb = mel_emb + self.audio_tags
         emb = torch.cat([text_emb,
                          self.separator.repeat(text_emb.shape[0],1,1),
                          mel_emb], dim=1)
-        enc = self.gpt(emb)
+        enc = self.gpt(emb, text_emb.shape[1])
         mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
         gates = self.gate_head(mel_portion).squeeze(1)
         mel_pred = self.mel_head(mel_portion)
@@ -82,6 +77,9 @@ class GptTts(nn.Module):
         if padded:
             mel_pred = mel_pred[:, :, :-1]
             gates = gates[:, :-1]
+        #postnet_mel_pred = self.postnet(mel_pred)
+        #return mel_pred, postnet_mel_pred, gates
         return mel_pred, gates
 
     def test_guide(self, mel_guide, amount=50):
@@ -95,15 +93,16 @@ class GptTts(nn.Module):
         GATE_THRESHOLD = .95
 
         text_emb = self.text_embedding(text_inputs)
+        text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
         text_emb = text_emb + self.text_tags
         b,s,c = text_emb.shape
         emb = torch.cat([text_emb,
-                         self.separator.repeat(text_emb.shape[0],1,1)], dim=1)
+                         self.separator.repeat(text_emb.shape[0],1,1),], dim=1)
                          #self.test_guide(mel_guide)], dim=1)
         completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
         output = None
         for i in tqdm(range(self.max_mel_frames)):
-            enc = self.gpt(emb)
+            enc = self.gpt(emb, text_emb.shape[1])
             inferred = enc[:,s:,:].permute(0,2,1)
             # Create output frames.
             inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
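The hunk cuts off before the loop's stopping logic, but given GATE_THRESHOLD above, the usual Tacotron-style pattern is to threshold the sigmoid of the newest gate output and mark finished sequences. A hedged sketch of that pattern (the actual update is not visible in this diff):

import torch

# Illustrative gate-based stopping; names mirror the loop above.
GATE_THRESHOLD = .95
b = 2
completed = torch.zeros((b,), dtype=torch.bool)
gate_logits = torch.randn(b)   # gate head output for the newest mel frame of each sequence
completed = completed | (torch.sigmoid(gate_logits) > GATE_THRESHOLD)
if completed.all():
    print('every sequence has fired its stop gate; decoding can halt early')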
@@ -143,9 +142,10 @@ if __name__ == '__main__':
                  torch.randn(2,80,747),
                  torch.tensor([600,747]))
     print(m.shape)
+    #print(p.shape)
     print(g.shape)
 
-    o = gpt.infer(torch.randint(high=24, size=(2,60)))
-    print(o.shape)
+    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
+    #print(o.shape)

View File

@@ -56,7 +56,7 @@ class CausalSelfAttention(nn.Module):
                                      .view(1, 1, config.block_size, config.block_size))
         self.n_head = config.n_head
 
-    def forward(self, x, layer_past=None):
+    def forward(self, x, text_block_size):
         B, T, C = x.size()
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -66,7 +66,9 @@ class CausalSelfAttention(nn.Module):
 
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+        att = att.masked_fill(self.mask[:,:,:T,:T].logical_or(
+            F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0,
+            float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.attn_drop(att)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
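To make the intent of the new mask concrete: the lower-triangular causal mask is OR'd with a dense block over the first text_block_size positions, so text tokens attend to each other bidirectionally while mel positions remain strictly causal. A standalone check of the construction, with arbitrary sizes:

import torch
import torch.nn.functional as F

# Arbitrary sizes for illustration.
B, n_head, T, text_block_size = 1, 2, 6, 3

causal = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
# Dense ones over the text region, zero-padded out to (T, T).
text_block = F.pad(torch.ones((B, n_head, text_block_size, text_block_size)),
                   (0, T - text_block_size, 0, T - text_block_size))
mask = causal.logical_or(text_block)
# Rows 0..2 (text) see all text positions; rows 3..5 (mel) stay lower-triangular.
print(mask[0, 0].int())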
@@ -91,8 +93,8 @@ class Block(nn.Module):
             nn.Dropout(config.resid_pdrop),
         )
 
-    def forward(self, x):
-        x = x + self.attn(self.ln1(x))
+    def forward(self, x, text_block_size):
+        x = x + self.attn(self.ln1(x), text_block_size)
         x = x + self.mlp(self.ln2(x))
         return x
@@ -171,13 +173,14 @@ class GPT(nn.Module):
         optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
         return optimizer
 
-    def forward(self, embeddings):
+    def forward(self, embeddings, text_block_sizes):
         b, t, c = embeddings.size()
         assert t <= self.block_size, "Cannot forward, model block size is exhausted."
 
         # forward the GPT model
         position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
         x = self.drop(embeddings + position_embeddings)
-        x = self.blocks(x)
+        for block in self.blocks:
+            x = block(x, text_block_sizes)
 
         return x
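The old `x = self.blocks(x)` relied on nn.Sequential, which threads only a single argument; the explicit loop is what lets text_block_sizes reach each block. A minimal sketch of the pattern (Block here is a stand-in, not the repo's class):

import torch
import torch.nn as nn

# Stand-in block; the real Block forwards text_block_size into its attention layer.
class Block(nn.Module):
    def forward(self, x, text_block_size):
        return x

blocks = nn.ModuleList([Block() for _ in range(4)])
x = torch.randn(2, 8, 16)
for block in blocks:
    x = block(x, 3)   # nn.Sequential(x) could not pass this second argument through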

View File

@@ -0,0 +1,49 @@
+import torch
+
+
+# "long" and "short" denote longer and shorter samples
+class PixelShuffle1D(torch.nn.Module):
+    """
+    1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf
+    Upscales sample length, downscales channel length
+    "short" is input, "long" is output
+    """
+    def __init__(self, upscale_factor):
+        super(PixelShuffle1D, self).__init__()
+        self.upscale_factor = upscale_factor
+
+    def forward(self, x):
+        batch_size = x.shape[0]
+        short_channel_len = x.shape[1]
+        short_width = x.shape[2]
+
+        long_channel_len = short_channel_len // self.upscale_factor
+        long_width = self.upscale_factor * short_width
+
+        x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width])
+        x = x.permute(0, 2, 3, 1).contiguous()
+        x = x.view(batch_size, long_channel_len, long_width)
+        return x
+
+
+class PixelUnshuffle1D(torch.nn.Module):
+    """
+    Inverse of 1D pixel shuffler
+    Upscales channel length, downscales sample length
+    "long" is input, "short" is output
+    """
+    def __init__(self, downscale_factor):
+        super(PixelUnshuffle1D, self).__init__()
+        self.downscale_factor = downscale_factor
+
+    def forward(self, x):
+        batch_size = x.shape[0]
+        long_channel_len = x.shape[1]
+        long_width = x.shape[2]
+
+        short_channel_len = long_channel_len * self.downscale_factor
+        short_width = long_width // self.downscale_factor
+
+        x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor])
+        x = x.permute(0, 3, 1, 2).contiguous()
+        x = x.view([batch_size, short_channel_len, short_width])
+        return x
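Since the encoder applies PixelUnshuffle1D(2) and the heads apply PixelShuffle1D(2), the two should be exact inverses. A quick round-trip check (my own test, assuming the module path used by the new import in gpt_tts.py):

import torch
from models.gpt_voice.pixelshuffle_1d import PixelShuffle1D, PixelUnshuffle1D

# Round-trip: unshuffle then shuffle should reproduce the input exactly.
x = torch.randn(2, 64, 100)           # (batch, channels, length)
y = PixelUnshuffle1D(2)(x)
assert y.shape == (2, 128, 50)        # channels doubled, length halved
z = PixelShuffle1D(2)(y)
assert z.shape == x.shape
assert torch.equal(z, x)              # lossless rearrangement, no values changed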

View File

@@ -3,6 +3,8 @@ import logging
 import random
 import argparse
 
+import torchvision
+
 import utils
 import utils.options as option
 import utils.util as util
@@ -19,11 +21,20 @@ def forward_pass(model, denoiser, data, output_dir, opt, b):
     with torch.no_grad():
         model.feed_data(data, 0)
         model.test()
 
     pred_waveforms = model.eval_state[opt['eval']['output_state']][0]
     pred_waveforms = denoiser(pred_waveforms)
     ground_truth_waveforms = model.eval_state[opt['eval']['ground_truth']][0]
     ground_truth_waveforms = denoiser(ground_truth_waveforms)
     for i in range(pred_waveforms.shape[0]):
+        # Output predicted mels and waveforms.
+        pred_mel = model.eval_state[opt['eval']['pred_mel']][i]
+        pred_mel = ((pred_mel - pred_mel.mean()) / max(abs(pred_mel.min()), pred_mel.max())).unsqueeze(1)
+        torchvision.utils.save_image(pred_mel, osp.join(output_dir, f'{b}_{i}_pred_mel.png'))
+        gt_mel = model.eval_state[opt['eval']['ground_truth_mel']][i]
+        gt_mel = ((gt_mel - gt_mel.mean()) / max(abs(gt_mel.min()), gt_mel.max())).unsqueeze(1)
+        torchvision.utils.save_image(gt_mel, osp.join(output_dir, f'{b}_{i}_gt_mel.png'))
+
         audio = pred_waveforms[i][0].cpu().numpy()
         wavfile.write(osp.join(output_dir, f'{b}_{i}.wav'), 22050, audio)
         audio = ground_truth_waveforms[i][0].cpu().numpy()
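One caveat on the added normalization: `(mel - mean) / max(|min|, max)` centers the tensor but can leave negative values, and torchvision's save_image clamps to [0, 1] when writing, so part of the dynamic range may render black. A min-max rescale is one alternative sketch (not what this commit does):

import torch
import torchvision

def mel_to_image(mel: torch.Tensor) -> torch.Tensor:
    # mel: (n_mel, n_frames); returns a single-channel image tensor in [0, 1].
    mel = mel - mel.min()
    return (mel / mel.max().clamp(min=1e-8)).unsqueeze(0)

mel = torch.randn(80, 200)   # hypothetical mel spectrogram
torchvision.utils.save_image(mel_to_image(mel), 'example_mel.png')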

View File

@@ -54,7 +54,7 @@ class SGDNoBiasMomentum(Optimizer):
                 if weight_decay != 0:
                     d_p = d_p.add(p, alpha=weight_decay)
                 # **this is the only modification over standard torch.optim.SGD:
-                is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
+                is_bn_or_bias = (hasattr(p, 'is_norm') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
                 if not is_bn_or_bias and momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
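This optimizer depends on attributes stamped onto the parameters themselves; nothing in this hunk sets them. A hedged sketch of how a model might be tagged so norm and bias parameters skip momentum — the tagging loop is illustrative, and note the changed check reads p.is_bn after testing hasattr(p, 'is_norm'), so the sketch sets both:

import torch.nn as nn

# Illustrative only; the repo may tag parameters elsewhere.
def tag_params(model: nn.Module):
    for module in model.modules():
        is_norm = isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm))
        for name, p in module.named_parameters(recurse=False):
            if is_norm:
                p.is_norm = True
                p.is_bn = True    # the check above reads p.is_bn once hasattr(p, 'is_norm') passes
            if name == 'bias':
                p.is_bias = True

model = nn.Sequential(nn.Conv1d(8, 8, 3), nn.BatchNorm1d(8))
tag_params(model)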