From dadc54795c2930b83128fb77f93dcf835e380624 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 27 Jul 2021 20:33:30 -0600 Subject: [PATCH] Add gpt_tts --- codes/data/audio/nv_tacotron_dataset.py | 24 ++-- codes/models/gpt_voice/__init__.py | 0 codes/models/gpt_voice/gpt_tts.py | 77 ++++++++++ codes/models/gpt_voice/min_gpt.py | 183 ++++++++++++++++++++++++ codes/train.py | 2 +- 5 files changed, 274 insertions(+), 12 deletions(-) create mode 100644 codes/models/gpt_voice/__init__.py create mode 100644 codes/models/gpt_voice/gpt_tts.py create mode 100644 codes/models/gpt_voice/min_gpt.py diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py index 10d5eae1..4950839c 100644 --- a/codes/data/audio/nv_tacotron_dataset.py +++ b/codes/data/audio/nv_tacotron_dataset.py @@ -3,11 +3,13 @@ import random import numpy as np import torch import torch.utils.data +from tqdm import tqdm import models.tacotron2.layers as layers from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text from models.tacotron2.text import text_to_sequence +from utils.util import opt_get class TextMelLoader(torch.utils.data.Dataset): @@ -23,8 +25,8 @@ class TextMelLoader(torch.utils.data.Dataset): self.max_wav_value = hparams.max_wav_value self.sampling_rate = hparams.sampling_rate self.load_mel_from_disk = hparams.load_mel_from_disk - self.return_wavs = hparams.return_wavs - self.input_sample_rate = hparams.input_sample_rate + self.return_wavs = opt_get(hparams, ['return_wavs'], False) + self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate) assert not (self.load_mel_from_disk and self.return_wavs) self.stft = layers.TacotronSTFT( hparams.filter_length, hparams.hop_length, hparams.win_length, @@ -134,10 +136,10 @@ if __name__ == '__main__': 'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt', 'phase': 'train', 'n_workers': 0, - 'batch_size': 2, - 'return_wavs': True, - 'input_sample_rate': 22050, - 'sampling_rate': 8000 + 'batch_size': 16, + #'return_wavs': True, + #'input_sample_rate': 22050, + #'sampling_rate': 8000 } from data import create_dataset, create_dataloader @@ -145,10 +147,10 @@ if __name__ == '__main__': dl = create_dataloader(ds, params, collate_fn=c) i = 0 m = [] - for b in dl: - m.append(b) - i += 1 - if i > 9999: - break + max_text = 0 + max_mel = 0 + for b in tqdm(dl): + max_mel = max(max_mel, b['padded_mel'].shape[2]) + max_text = max(max_text, b['padded_text'].shape[1]) m=torch.stack(m) print(m.mean(), m.std()) diff --git a/codes/models/gpt_voice/__init__.py b/codes/models/gpt_voice/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py new file mode 100644 index 00000000..a8648e8f --- /dev/null +++ b/codes/models/gpt_voice/gpt_tts.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.arch_util import ConvGnSilu +from models.tacotron2.taco_utils import get_mask_from_lengths +from models.tacotron2.text import symbols +from models.gpt_voice.min_gpt import GPT, GPTConfig +from trainer.networks import register_model + + +class GptTts(nn.Module): + def __init__(self): + super().__init__() + number_symbols = len(symbols) + model_dim = 512 + max_symbols_per_phrase = 200 + max_mel_frames = 900 + mel_dim=80 + + self.text_embedding = nn.Embedding(number_symbols, model_dim) + self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=3, convnd=nn.Conv1d), + ConvGnSilu(model_dim//2, model_dim, kernel_size=3, stride=2, convnd=nn.Conv1d)) + self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0) + self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0) + self.gpt = GPT(GPTConfig(max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8)) + + self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d), + nn.Upsample(scale_factor=2, mode='nearest'), + ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d), + nn.Conv1d(model_dim//2, 1, kernel_size=1)) + self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d), + nn.Upsample(scale_factor=2, mode='nearest'), + ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d), + ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=nn.Conv1d), + ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d)) + + def forward(self, text_inputs, mel_targets, output_lengths): + # Pad mel_targets to be a multiple of 2 + padded = mel_targets.shape[-1] % 2 != 0 + if padded: + mel_targets = F.pad(mel_targets, (0,1)) + + text_emb = self.text_embedding(text_inputs) + text_emb = text_emb + self.text_tags + mel_emb = self.mel_encoder(mel_targets).permute(0,2,1) + mel_emb = mel_emb + self.audio_tags + emb = torch.cat([text_emb, mel_emb], dim=1) + enc = self.gpt(emb) + mel_portion = enc[:, text_emb.shape[1]:].permute(0,2,1) + gates = self.gate_head(mel_portion).squeeze(1) + mel_pred = self.mel_head(mel_portion) + + # Mask portions of output which we don't need to predict. + mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1]) + mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1) + mel_pred.data.masked_fill_(mask, 0) + gates.data.masked_fill_(mask[:, 0, :], 1e3) + + if padded: + mel_pred = mel_pred[:, :, :-1] + gates = gates[:, :-1] + return mel_pred, gates + + +@register_model +def register_gpt_tts(opt_net, opt): + return GptTts() + + +if __name__ == '__main__': + gpt = GptTts() + m, g = gpt(torch.randint(high=24, size=(2,60)), + torch.randn(2,80,747), + torch.tensor([600,747])) + print(m.shape) + print(g.shape) \ No newline at end of file diff --git a/codes/models/gpt_voice/min_gpt.py b/codes/models/gpt_voice/min_gpt.py new file mode 100644 index 00000000..19a5189b --- /dev/null +++ b/codes/models/gpt_voice/min_gpt.py @@ -0,0 +1,183 @@ +""" +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier + +Original author: karpathy@, https://github.com/karpathy/minGPT +""" + +import math +import logging + +import torch +import torch.nn as nn +from torch.nn import functional as F + +logger = logging.getLogger(__name__) + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, block_size, n_layer=12, n_head=12, n_embd=768, **kwargs): + self.block_size = block_size + self.n_layer = n_layer + self.n_head = n_head + self.n_embd = n_embd + for k,v in kwargs.items(): + setattr(self, k, v) + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + x = x + self.attn(self.ln1(x)) + x = x + self.mlp(self.ln2(x)) + return x + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + + def __init__(self, config): + super().__init__() + + # input embedding stem + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + + self.block_size = config.block_size + self.apply(self._init_weights) + + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def configure_optimizers(self, train_config): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. + """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) + return optimizer + + def forward(self, embeddings): + b, t, c = embeddings.size() + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + + # forward the GPT model + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(embeddings + position_embeddings) + x = self.blocks(x) + + return x \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index b22bf7f5..da87349b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -300,7 +300,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wave_tacotron_diffusion_lj.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args()