forked from mrq/DL-Art-School

Add gpt_tts

This commit is contained in:
  parent 398185e109
  commit dadc54795c
@@ -3,11 +3,13 @@ import random
 import numpy as np
 import torch
 import torch.utils.data
+from tqdm import tqdm
 
 import models.tacotron2.layers as layers
 from models.tacotron2.taco_utils import load_wav_to_torch, load_filepaths_and_text
 from models.tacotron2.text import text_to_sequence
+from utils.util import opt_get
 
 
 class TextMelLoader(torch.utils.data.Dataset):
@@ -23,8 +25,8 @@ class TextMelLoader(torch.utils.data.Dataset):
         self.max_wav_value = hparams.max_wav_value
         self.sampling_rate = hparams.sampling_rate
         self.load_mel_from_disk = hparams.load_mel_from_disk
-        self.return_wavs = hparams.return_wavs
-        self.input_sample_rate = hparams.input_sample_rate
+        self.return_wavs = opt_get(hparams, ['return_wavs'], False)
+        self.input_sample_rate = opt_get(hparams, ['input_sample_rate'], self.sampling_rate)
         assert not (self.load_mel_from_disk and self.return_wavs)
         self.stft = layers.TacotronSTFT(
             hparams.filter_length, hparams.hop_length, hparams.win_length,
@@ -134,10 +136,10 @@ if __name__ == '__main__':
        'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
        'phase': 'train',
        'n_workers': 0,
-       'batch_size': 2,
-       'return_wavs': True,
-       'input_sample_rate': 22050,
-       'sampling_rate': 8000
+       'batch_size': 16,
+       #'return_wavs': True,
+       #'input_sample_rate': 22050,
+       #'sampling_rate': 8000
     }
     from data import create_dataset, create_dataloader
 
@@ -145,10 +147,10 @@ if __name__ == '__main__':
     dl = create_dataloader(ds, params, collate_fn=c)
     i = 0
     m = []
-    for b in dl:
-        m.append(b)
-        i += 1
-        if i > 9999:
-            break
+    max_text = 0
+    max_mel = 0
+    for b in tqdm(dl):
+        max_mel = max(max_mel, b['padded_mel'].shape[2])
+        max_text = max(max_text, b['padded_text'].shape[1])
     m=torch.stack(m)
     print(m.mean(), m.std())
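The switch from direct attribute access to `opt_get(hparams, ['return_wavs'], False)` makes the two new options optional. `opt_get` lives in `utils/util.py`, which is not part of this diff; a minimal sketch of the behaviour the loader appears to rely on (a nested-key lookup with a default) might look like the following, though the real helper may differ:

# Hypothetical stand-in for utils.util.opt_get, shown only to illustrate the
# calls above: walk `opt` through the list of keys and fall back to `default`
# whenever a key is missing.
def opt_get(opt, keys, default=None):
    ret = opt
    for k in keys:
        if isinstance(ret, dict):
            ret = ret.get(k, None)
        else:
            ret = getattr(ret, k, None)
        if ret is None:
            return default
    return ret

# Example: with no 'input_sample_rate' key, the loader keeps self.sampling_rate.
print(opt_get({'sampling_rate': 22050}, ['input_sample_rate'], 22050))  # 22050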
  0  codes/models/gpt_voice/__init__.py   Normal file
 77  codes/models/gpt_voice/gpt_tts.py    Normal file
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from models.arch_util import ConvGnSilu
+from models.tacotron2.taco_utils import get_mask_from_lengths
+from models.tacotron2.text import symbols
+from models.gpt_voice.min_gpt import GPT, GPTConfig
+from trainer.networks import register_model
+
+
+class GptTts(nn.Module):
+    def __init__(self):
+        super().__init__()
+        number_symbols = len(symbols)
+        model_dim = 512
+        max_symbols_per_phrase = 200
+        max_mel_frames = 900
+        mel_dim=80
+
+        self.text_embedding = nn.Embedding(number_symbols, model_dim)
+        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=3, convnd=nn.Conv1d),
+                                         ConvGnSilu(model_dim//2, model_dim, kernel_size=3, stride=2, convnd=nn.Conv1d))
+        self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
+        self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
+        self.gpt = GPT(GPTConfig(max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
+
+        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d),
+                                       nn.Upsample(scale_factor=2, mode='nearest'),
+                                       ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
+                                       nn.Conv1d(model_dim//2, 1, kernel_size=1))
+        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=nn.Conv1d),
+                                      nn.Upsample(scale_factor=2, mode='nearest'),
+                                      ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
+                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=nn.Conv1d),
+                                      ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d))
+
+    def forward(self, text_inputs, mel_targets, output_lengths):
+        # Pad mel_targets to be a multiple of 2
+        padded = mel_targets.shape[-1] % 2 != 0
+        if padded:
+            mel_targets = F.pad(mel_targets, (0,1))
+
+        text_emb = self.text_embedding(text_inputs)
+        text_emb = text_emb + self.text_tags
+        mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
+        mel_emb = mel_emb + self.audio_tags
+        emb = torch.cat([text_emb, mel_emb], dim=1)
+        enc = self.gpt(emb)
+        mel_portion = enc[:, text_emb.shape[1]:].permute(0,2,1)
+        gates = self.gate_head(mel_portion).squeeze(1)
+        mel_pred = self.mel_head(mel_portion)
+
+        # Mask portions of output which we don't need to predict.
+        mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1])
+        mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1)
+        mel_pred.data.masked_fill_(mask, 0)
+        gates.data.masked_fill_(mask[:, 0, :], 1e3)
+
+        if padded:
+            mel_pred = mel_pred[:, :, :-1]
+            gates = gates[:, :-1]
+        return mel_pred, gates
+
+
+@register_model
+def register_gpt_tts(opt_net, opt):
+    return GptTts()
+
+
+if __name__ == '__main__':
+    gpt = GptTts()
+    m, g = gpt(torch.randint(high=24, size=(2,60)),
+               torch.randn(2,80,747),
+               torch.tensor([600,747]))
+    print(m.shape)
+    print(g.shape)
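For orientation, a rough shape trace of the forward pass in gpt_tts.py, following the __main__ example above. It assumes the stride-2 ConvGnSilu halves the time axis and that get_mask_from_lengths returns a (batch, time) boolean mask; both helpers live elsewhere in the repo, so treat this as a sketch rather than a verified run:

# Hypothetical shape trace for gpt(torch.randint(high=24, size=(2,60)),
#                                  torch.randn(2,80,747), torch.tensor([600,747])):
# mel_targets  (2, 80, 747) -> F.pad to (2, 80, 748) so the length divides by 2
# text_emb     (2, 60, 512)
# mel_emb      mel_encoder downsamples time by 2 -> (2, 374, 512)
# emb          concat on the sequence axis -> (2, 60 + 374, 512) = (2, 434, 512)
#              434 fits within the GPT block_size of 200 + 900 // 2 = 650
# mel_portion  (2, 512, 374)
# gate_head / mel_head upsample time by 2 -> gates (2, 748), mel_pred (2, 80, 748)
# trimming the pad gives mel_pred (2, 80, 747) and gates (2, 747)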
183  codes/models/gpt_voice/min_gpt.py    Normal file

@@ -0,0 +1,183 @@
+"""
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+- each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+- all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+
+Original author: karpathy@, https://github.com/karpathy/minGPT
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+logger = logging.getLogger(__name__)
+
+class GPTConfig:
+    """ base GPT config, params common to all GPT versions """
+    embd_pdrop = 0.1
+    resid_pdrop = 0.1
+    attn_pdrop = 0.1
+
+    def __init__(self, block_size, n_layer=12, n_head=12, n_embd=768, **kwargs):
+        self.block_size = block_size
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.n_embd = n_embd
+        for k,v in kwargs.items():
+            setattr(self, k, v)
+
+class CausalSelfAttention(nn.Module):
+    """
+    A vanilla multi-head masked self-attention layer with a projection at the end.
+    It is possible to use torch.nn.MultiheadAttention here but I am including an
+    explicit implementation here to show that there is nothing too scary here.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        # key, query, value projections for all heads
+        self.key = nn.Linear(config.n_embd, config.n_embd)
+        self.query = nn.Linear(config.n_embd, config.n_embd)
+        self.value = nn.Linear(config.n_embd, config.n_embd)
+        # regularization
+        self.attn_drop = nn.Dropout(config.attn_pdrop)
+        self.resid_drop = nn.Dropout(config.resid_pdrop)
+        # output projection
+        self.proj = nn.Linear(config.n_embd, config.n_embd)
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
+                                     .view(1, 1, config.block_size, config.block_size))
+        self.n_head = config.n_head
+
+    def forward(self, x, layer_past=None):
+        B, T, C = x.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+        att = F.softmax(att, dim=-1)
+        att = self.attn_drop(att)
+        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.resid_drop(self.proj(y))
+        return y
+
+class Block(nn.Module):
+    """ an unassuming Transformer block """
+
+    def __init__(self, config):
+        super().__init__()
+        self.ln1 = nn.LayerNorm(config.n_embd)
+        self.ln2 = nn.LayerNorm(config.n_embd)
+        self.attn = CausalSelfAttention(config)
+        self.mlp = nn.Sequential(
+            nn.Linear(config.n_embd, 4 * config.n_embd),
+            nn.GELU(),
+            nn.Linear(4 * config.n_embd, config.n_embd),
+            nn.Dropout(config.resid_pdrop),
+        )
+
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))
+        x = x + self.mlp(self.ln2(x))
+        return x
+
+class GPT(nn.Module):
+    """ the full GPT language model, with a context size of block_size """
+
+    def __init__(self, config):
+        super().__init__()
+
+        # input embedding stem
+        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+        self.drop = nn.Dropout(config.embd_pdrop)
+        # transformer
+        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+
+        self.block_size = config.block_size
+        self.apply(self._init_weights)
+
+        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+    def get_block_size(self):
+        return self.block_size
+
+    def _init_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if isinstance(module, nn.Linear) and module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def configure_optimizers(self, train_config):
+        """
+        This long function is unfortunately doing something very simple and is being very defensive:
+        We are separating out all parameters of the model into two buckets: those that will experience
+        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+        We are then returning the PyTorch optimizer object.
+        """
+
+        # separate out all parameters to those that will and won't experience regularizing weight decay
+        decay = set()
+        no_decay = set()
+        whitelist_weight_modules = (torch.nn.Linear, )
+        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        for mn, m in self.named_modules():
+            for pn, p in m.named_parameters():
+                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+                if pn.endswith('bias'):
+                    # all biases will not be decayed
+                    no_decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+                    # weights of whitelist modules will be weight decayed
+                    decay.add(fpn)
+                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+                    # weights of blacklist modules will NOT be weight decayed
+                    no_decay.add(fpn)
+
+        # special case the position embedding parameter in the root GPT module as not decayed
+        no_decay.add('pos_emb')
+
+        # validate that we considered every parameter
+        param_dict = {pn: p for pn, p in self.named_parameters()}
+        inter_params = decay & no_decay
+        union_params = decay | no_decay
+        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+                                                    % (str(param_dict.keys() - union_params), )
+
+        # create the pytorch optimizer object
+        optim_groups = [
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+        ]
+        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        return optimizer
+
+    def forward(self, embeddings):
+        b, t, c = embeddings.size()
+        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+
+        # forward the GPT model
+        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+        x = self.drop(embeddings + position_embeddings)
+        x = self.blocks(x)
+
+        return x
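As a quick sanity check of how gpt_tts.py drives this module, a small usage sketch; the sizes mirror GptTts (block_size 650 = 200 + 900 // 2, n_embd 512, n_head 8, n_layer left at its default of 12) but are otherwise illustrative:

import torch
from models.gpt_voice.min_gpt import GPT, GPTConfig

# Sketch only: feed pre-computed embeddings straight into the transformer.
config = GPTConfig(block_size=650, n_layer=12, n_head=8, n_embd=512)
gpt = GPT(config)

emb = torch.randn(2, 434, 512)   # (batch, seq <= block_size, n_embd)
out = gpt(emb)                   # transformed sequence, same shape back
print(out.shape)                 # torch.Size([2, 434, 512])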
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wave_tacotron_diffusion_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()