Improvements to GptTts
This commit is contained in:
parent
31ee9ae262
commit
0c9e75bc69
|
@ -1,27 +1,19 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from munch import munchify
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.arch_util import ConvGnSilu
|
||||
from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D
|
||||
from models.tacotron2 import hparams
|
||||
from models.tacotron2.taco_utils import get_mask_from_lengths
|
||||
from models.tacotron2.tacotron2 import Postnet
|
||||
from models.tacotron2.text import symbols
|
||||
from models.gpt_voice.min_gpt import GPT, GPTConfig
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
# A Conv1d that masks out kernel elements ahead of the current location.
|
||||
class CausalConv1d(nn.Conv1d):
    """A Conv1d that masks out kernel elements ahead of the current location.

    The trailing ``kernel_size // 2`` taps of the kernel are multiplied by
    zero on every forward pass, so they never contribute to the output.
    NOTE(review): this only yields a causal convolution when the layer is
    built with symmetric padding of ``kernel_size // 2`` - confirm at the
    call sites.

    Accepts exactly the same constructor arguments as ``nn.Conv1d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mask = torch.ones_like(self.weight)
        lookahead = self.kernel_size[0] // 2
        # Guard the zero case explicitly: for kernel_size == 1 the slice
        # `-0:` selects the *entire* kernel and would zero every weight.
        if lookahead > 0:
            mask[:, :, -lookahead:] = 0
        # A buffer follows the module across .to()/.cuda()/.half() calls.
        # persistent=False keeps it out of state_dict so existing
        # checkpoints still load without missing/unexpected keys.
        self.register_buffer('kernel_mask', mask, persistent=False)

    def forward(self, input):
        """Convolve `input` using the weights with look-ahead taps zeroed."""
        return self._conv_forward(input, self.weight * self.kernel_mask, self.bias)
|
||||
|
||||
|
||||
class GptTts(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -36,24 +28,26 @@ class GptTts(nn.Module):
|
|||
self.text_embedding = nn.Embedding(number_symbols, model_dim)
|
||||
# Whenever we process MEL frames, we need to be careful to use causally masked convolutions to avoid adding bias
|
||||
# into the model which we cannot provide in inference.
|
||||
self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
|
||||
ConvGnSilu(model_dim//2, model_dim, kernel_size=5, stride=2, convnd=CausalConv1d))
|
||||
self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
|
||||
PixelUnshuffle1D(2),
|
||||
ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
|
||||
ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d))
|
||||
# *_tags are additively applied to the text and audio (MEL) embeddings, respectively.
|
||||
self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
|
||||
self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
|
||||
self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
|
||||
self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim))
|
||||
self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
|
||||
|
||||
self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
|
||||
nn.Upsample(scale_factor=2, mode='nearest'),
|
||||
ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
|
||||
# No need for causal convolutions when kernel_size=1
|
||||
nn.Conv1d(model_dim//2, 1, kernel_size=1))
|
||||
self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d),
|
||||
nn.Upsample(scale_factor=2, mode='nearest'),
|
||||
ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d),
|
||||
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=CausalConv1d),
|
||||
ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d))
|
||||
self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
|
||||
PixelShuffle1D(2),
|
||||
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
|
||||
ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
|
||||
self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
|
||||
PixelShuffle1D(2),
|
||||
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
|
||||
ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
|
||||
#self.postnet = Postnet(munchify(hparams.create_hparams()))
|
||||
|
||||
def forward(self, text_inputs, mel_targets, output_lengths):
|
||||
# Pad mel_targets to be a multiple of 2
|
||||
|
@ -62,13 +56,14 @@ class GptTts(nn.Module):
|
|||
mel_targets = F.pad(mel_targets, (0,1))
|
||||
|
||||
text_emb = self.text_embedding(text_inputs)
|
||||
text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
|
||||
text_emb = text_emb + self.text_tags
|
||||
mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
|
||||
mel_emb = mel_emb + self.audio_tags
|
||||
emb = torch.cat([text_emb,
|
||||
self.separator.repeat(text_emb.shape[0],1,1),
|
||||
mel_emb], dim=1)
|
||||
enc = self.gpt(emb)
|
||||
enc = self.gpt(emb, text_emb.shape[1])
|
||||
mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
|
||||
gates = self.gate_head(mel_portion).squeeze(1)
|
||||
mel_pred = self.mel_head(mel_portion)
|
||||
|
@ -82,6 +77,9 @@ class GptTts(nn.Module):
|
|||
if padded:
|
||||
mel_pred = mel_pred[:, :, :-1]
|
||||
gates = gates[:, :-1]
|
||||
|
||||
#postnet_mel_pred = self.postnet(mel_pred)
|
||||
#return mel_pred, postnet_mel_pred, gates
|
||||
return mel_pred, gates
|
||||
|
||||
def test_guide(self, mel_guide, amount=50):
|
||||
|
@ -95,15 +93,16 @@ class GptTts(nn.Module):
|
|||
GATE_THRESHOLD = .95
|
||||
|
||||
text_emb = self.text_embedding(text_inputs)
|
||||
text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
|
||||
text_emb = text_emb + self.text_tags
|
||||
b,s,c = text_emb.shape
|
||||
emb = torch.cat([text_emb,
|
||||
self.separator.repeat(text_emb.shape[0],1,1)], dim=1)
|
||||
self.separator.repeat(text_emb.shape[0],1,1),], dim=1)
|
||||
#self.test_guide(mel_guide)], dim=1)
|
||||
completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool)
|
||||
output = None
|
||||
for i in tqdm(range(self.max_mel_frames)):
|
||||
enc = self.gpt(emb)
|
||||
enc = self.gpt(emb, text_emb.shape[1])
|
||||
inferred = enc[:,s:,:].permute(0,2,1)
|
||||
# Create output frames.
|
||||
inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:]
|
||||
|
@ -143,9 +142,10 @@ if __name__ == '__main__':
|
|||
torch.randn(2,80,747),
|
||||
torch.tensor([600,747]))
|
||||
print(m.shape)
|
||||
#print(p.shape)
|
||||
print(g.shape)
|
||||
|
||||
o = gpt.infer(torch.randint(high=24, size=(2,60)))
|
||||
print(o.shape)
|
||||
#o = gpt.infer(torch.randint(high=24, size=(2,60)))
|
||||
#print(o.shape)
|
||||
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ class CausalSelfAttention(nn.Module):
|
|||
.view(1, 1, config.block_size, config.block_size))
|
||||
self.n_head = config.n_head
|
||||
|
||||
def forward(self, x, layer_past=None):
|
||||
def forward(self, x, text_block_size):
|
||||
B, T, C = x.size()
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
|
@ -66,10 +66,12 @@ class CausalSelfAttention(nn.Module):
|
|||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
|
||||
att = att.masked_fill(self.mask[:,:,:T,:T].logical_or(
|
||||
F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0,
|
||||
float('-inf'))
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_drop(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
|
||||
|
||||
# output projection
|
||||
|
@ -91,8 +93,8 @@ class Block(nn.Module):
|
|||
nn.Dropout(config.resid_pdrop),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.ln1(x))
|
||||
def forward(self, x, text_block_size):
|
||||
x = x + self.attn(self.ln1(x), text_block_size)
|
||||
x = x + self.mlp(self.ln2(x))
|
||||
return x
|
||||
|
||||
|
@ -171,13 +173,14 @@ class GPT(nn.Module):
|
|||
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
|
||||
return optimizer
|
||||
|
||||
def forward(self, embeddings):
|
||||
def forward(self, embeddings, text_block_sizes):
|
||||
b, t, c = embeddings.size()
|
||||
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
|
||||
|
||||
# forward the GPT model
|
||||
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
|
||||
x = self.drop(embeddings + position_embeddings)
|
||||
x = self.blocks(x)
|
||||
for block in self.blocks:
|
||||
x = block(x, text_block_sizes)
|
||||
|
||||
return x
|
49
codes/models/gpt_voice/pixelshuffle_1d.py
Normal file
49
codes/models/gpt_voice/pixelshuffle_1d.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import torch
|
||||
|
||||
# "long" and "short" denote longer and shorter samples
|
||||
class PixelShuffle1D(torch.nn.Module):
    """
    One-dimensional sub-pixel shuffle (https://arxiv.org/pdf/1609.05158.pdf).

    Trades channels for length: an input of shape
    (batch, channels, width) becomes
    (batch, channels // upscale_factor, width * upscale_factor).
    The "short" tensor is the input and the "long" tensor is the output.
    """
    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        factor = self.upscale_factor
        batch = x.size(0)
        in_channels = x.size(1)
        in_width = x.size(2)

        out_channels = in_channels // factor
        out_width = factor * in_width

        # Split the channel axis into (factor, out_channels) groups ...
        grouped = x.contiguous().view([batch, factor, out_channels, in_width])
        # ... move the factor axis to the end so it sits next to width ...
        interleaved = grouped.permute(0, 2, 3, 1).contiguous()
        # ... and fold it into the width axis.
        return interleaved.view(batch, out_channels, out_width)
|
||||
|
||||
class PixelUnshuffle1D(torch.nn.Module):
    """
    Inverse of the one-dimensional pixel shuffle.

    Trades length for channels: an input of shape
    (batch, channels, width) becomes
    (batch, channels * downscale_factor, width // downscale_factor).
    The "long" tensor is the input and the "short" tensor is the output.
    """
    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, x):
        factor = self.downscale_factor
        batch = x.size(0)
        in_channels = x.size(1)
        in_width = x.size(2)

        out_channels = in_channels * factor
        out_width = in_width // factor

        # Split the width axis into (out_width, factor) groups ...
        grouped = x.contiguous().view([batch, in_channels, out_width, factor])
        # ... move the factor axis in front of the channel axis ...
        stacked = grouped.permute(0, 3, 1, 2).contiguous()
        # ... and fold it into the channel axis.
        return stacked.view([batch, out_channels, out_width])
|
|
@ -3,6 +3,8 @@ import logging
|
|||
import random
|
||||
import argparse
|
||||
|
||||
import torchvision
|
||||
|
||||
import utils
|
||||
import utils.options as option
|
||||
import utils.util as util
|
||||
|
@ -19,11 +21,20 @@ def forward_pass(model, denoiser, data, output_dir, opt, b):
|
|||
with torch.no_grad():
|
||||
model.feed_data(data, 0)
|
||||
model.test()
|
||||
|
||||
pred_waveforms = model.eval_state[opt['eval']['output_state']][0]
|
||||
pred_waveforms = denoiser(pred_waveforms)
|
||||
ground_truth_waveforms = model.eval_state[opt['eval']['ground_truth']][0]
|
||||
ground_truth_waveforms = denoiser(ground_truth_waveforms)
|
||||
for i in range(pred_waveforms.shape[0]):
|
||||
# Output predicted mels and waveforms.
|
||||
pred_mel = model.eval_state[opt['eval']['pred_mel']][i]
|
||||
pred_mel = ((pred_mel - pred_mel.mean()) / max(abs(pred_mel.min()), pred_mel.max())).unsqueeze(1)
|
||||
torchvision.utils.save_image(pred_mel, osp.join(output_dir, f'{b}_{i}_pred_mel.png'))
|
||||
gt_mel = model.eval_state[opt['eval']['ground_truth_mel']][i]
|
||||
gt_mel = ((gt_mel - gt_mel.mean()) / max(abs(gt_mel.min()), gt_mel.max())).unsqueeze(1)
|
||||
torchvision.utils.save_image(gt_mel, osp.join(output_dir, f'{b}_{i}_gt_mel.png'))
|
||||
|
||||
audio = pred_waveforms[i][0].cpu().numpy()
|
||||
wavfile.write(osp.join(output_dir, f'{b}_{i}.wav'), 22050, audio)
|
||||
audio = ground_truth_waveforms[i][0].cpu().numpy()
|
||||
|
|
|
@ -54,7 +54,7 @@ class SGDNoBiasMomentum(Optimizer):
|
|||
if weight_decay != 0:
|
||||
d_p = d_p.add(p, alpha=weight_decay)
|
||||
# **this is the only modification over standard torch.optim.SGD:
|
||||
is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
|
||||
is_bn_or_bias = (hasattr(p, 'is_norm') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
|
||||
if not is_bn_or_bias and momentum != 0:
|
||||
param_state = self.state[p]
|
||||
if 'momentum_buffer' not in param_state:
|
||||
|
|
Loading…
Reference in New Issue
Block a user