Add gpt_tts

This commit is contained in:
James Betker 2021-07-27 20:33:30 -06:00
parent 398185e109
commit dadc54795c
5 changed files with 274 additions and 12 deletions

View File

@ -3,11 +3,13 @@ import random
import numpy as np import numpy as np
import torch import torch
import torch.utils.data import torch.utils.data
from tqdm import tqdm
import models.tacotron2.layers as layers import models.tacotron2.layers as layers
from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
from models.tacotron2.text import text_to_sequence from models.tacotron2.text import text_to_sequence
from utils.util import opt_get
class TextMelLoader(torch.utils.data.Dataset): class TextMelLoader(torch.utils.data.Dataset):
@ -23,8 +25,8 @@ class TextMelLoader(torch.utils.data.Dataset):
self.max_wav_value = hparams.max_wav_value self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate self.sampling_rate = hparams.sampling_rate
self.load_mel_from_disk = hparams.load_mel_from_disk self.load_mel_from_disk = hparams.load_mel_from_disk
self.return_wavs = hparams.return_wavs self.return_wavs = opt_get(hparams, ['return_wavs'], False)
self.input_sample_rate = hparams.input_sample_rate self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate)
assert not (self.load_mel_from_disk and self.return_wavs) assert not (self.load_mel_from_disk and self.return_wavs)
self.stft = layers.TacotronSTFT( self.stft = layers.TacotronSTFT(
hparams.filter_length, hparams.hop_length, hparams.win_length, hparams.filter_length, hparams.hop_length, hparams.win_length,
@ -134,10 +136,10 @@ if __name__ == '__main__':
'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt', 'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
'phase': 'train', 'phase': 'train',
'n_workers': 0, 'n_workers': 0,
'batch_size': 2, 'batch_size': 16,
'return_wavs': True, #'return_wavs': True,
'input_sample_rate': 22050, #'input_sample_rate': 22050,
'sampling_rate': 8000 #'sampling_rate': 8000
} }
from data import create_dataset, create_dataloader from data import create_dataset, create_dataloader
@ -145,10 +147,10 @@ if __name__ == '__main__':
dl = create_dataloader(ds, params, collate_fn=c) dl = create_dataloader(ds, params, collate_fn=c)
i = 0 i = 0
m = [] m = []
for b in dl: max_text = 0
m.append(b) max_mel = 0
i += 1 for b in tqdm(dl):
if i > 9999: max_mel = max(max_mel, b['padded_mel'].shape[2])
break max_text = max(max_text, b['padded_text'].shape[1])
m=torch.stack(m) m=torch.stack(m)
print(m.mean(), m.std()) print(m.mean(), m.std())

View File

View File

@ -0,0 +1,77 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.arch_util import ConvGnSilu
from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols
from models.gpt_voice.min_gpt import GPT, GPTConfig
from trainer.networks import register_model
class GptTts(nn.Module):
def __init__(self):
super().__init__()
number_symbols = len(symbols)
model_dim = 512
max_symbols_per_phrase = 200
max_mel_frames = 900
mel_dim=80
self.text_embedding = nn.Embedding(number_symbols, model_dim)
self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=3, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, model_dim, kernel_size=3, stride=2, convnd=nn.Conv1d))
self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
self.gpt = GPT(GPTConfig(max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d),
nn.Upsample(scale_factor=2, mode='nearest'),
ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
nn.Conv1d(model_dim//2, 1, kernel_size=1))
self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d),
nn.Upsample(scale_factor=2, mode='nearest'),
ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d))
def forward(self, text_inputs, mel_targets, output_lengths):
# Pad mel_targets to be a multiple of 2
padded = mel_targets.shape[-1] % 2 != 0
if padded:
mel_targets = F.pad(mel_targets, (0,1))
text_emb = self.text_embedding(text_inputs)
text_emb = text_emb + self.text_tags
mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
mel_emb = mel_emb + self.audio_tags
emb = torch.cat([text_emb, mel_emb], dim=1)
enc = self.gpt(emb)
mel_portion = enc[:, text_emb.shape[1]:].permute(0,2,1)
gates = self.gate_head(mel_portion).squeeze(1)
mel_pred = self.mel_head(mel_portion)
# Mask portions of output which we don't need to predict.
mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1])
mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1)
mel_pred.data.masked_fill_(mask, 0)
gates.data.masked_fill_(mask[:, 0, :], 1e3)
if padded:
mel_pred = mel_pred[:, :, :-1]
gates = gates[:, :-1]
return mel_pred, gates
@register_model
def register_gpt_tts(opt_net, opt):
return GptTts()
if __name__ == '__main__':
gpt = GptTts()
m, g = gpt(torch.randint(high=24, size=(2,60)),
torch.randn(2,80,747),
torch.tensor([600,747]))
print(m.shape)
print(g.shape)

View File

@ -0,0 +1,183 @@
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
- all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
Original author: karpathy@, https://github.com/karpathy/minGPT
"""
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
logger = logging.getLogger(__name__)
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, block_size, n_layer=12, n_head=12, n_embd=768, **kwargs):
self.block_size = block_size
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
for k,v in kwargs.items():
setattr(self, k, v)
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
It is possible to use torch.nn.MultiheadAttention here but I am including an
explicit implementation here to show that there is nothing too scary here.
"""
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads
self.key = nn.Linear(config.n_embd, config.n_embd)
self.query = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
# regularization
self.attn_drop = nn.Dropout(config.attn_pdrop)
self.resid_drop = nn.Dropout(config.resid_pdrop)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
def forward(self, x, layer_past=None):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, config):
super().__init__()
# input embedding stem
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
self.block_size = config.block_size
self.apply(self._init_weights)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def get_block_size(self):
return self.block_size
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
return optimizer
def forward(self, embeddings):
b, t, c = embeddings.size()
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
# forward the GPT model
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(embeddings + position_embeddings)
x = self.blocks(x)
return x

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wave_tacotron_diffusion_lj.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()