Get dalle-style TTS to "work"

This commit is contained in:
James Betker 2021-08-03 21:08:27 -06:00
parent 2814307eee
commit 4c98b9703f
6 changed files with 127 additions and 85 deletions

View File

@ -15,78 +15,86 @@ from trainer.networks import register_model
class GptTts(nn.Module):
NUMBER_SYMBOLS = len(symbols)+3
TEXT_START_TOKEN = NUMBER_SYMBOLS-3
TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2
TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
MEL_DICTIONARY_SIZE = 512+3
MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1
def __init__(self):
super().__init__()
number_symbols = len(symbols)
model_dim = 512
max_symbols_per_phrase = 200
max_mel_frames = 900
max_mel_frames = 900 * 3 // 8 # The VQVAE outputs 3/8 of the input mel as tokens.
mel_dim=80
self.model_dim = model_dim
self.max_mel_frames = max_mel_frames
self.text_embedding = nn.Embedding(number_symbols, model_dim)
# Whenever we process MEL frames, we need to be careful to use casually masked convolutions to avoid adding bias
# into the model which we cannot provide in inference.
self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
PixelUnshuffle1D(2),
ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d))
self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim)
self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
# *_tags are additively applied to
self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim))
self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim)
self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
PixelShuffle1D(2),
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
PixelShuffle1D(2),
ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
#self.postnet = Postnet(munchify(hparams.create_hparams()))
self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS)
self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)
def forward(self, text_inputs, mel_targets, output_lengths):
# Pad mel_targets to be a multiple of 2
padded = mel_targets.shape[-1] % 2 != 0
if padded:
mel_targets = F.pad(mel_targets, (0,1))
def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
output_lengths = output_lengths * 3 // 8 # The data we are dealing with has been compressed by the vqvae.
# Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level.
batch_range = torch.arange(0, text_inputs.shape[0])
text_inputs = F.pad(text_inputs, (0,1))
text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
text_lengths = text_lengths + 1
mel_targets = F.pad(mel_targets, (0,1))
mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
output_lengths = output_lengths + 1
# Add the start tokens to the beginnings of the texts and mels.
text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN)
text_lengths = text_lengths + 1
mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN)
output_lengths = output_lengths + 1
# Add padding as well. This also should realistically be done at the dataloader level.
text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN)
mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN)
text_emb = self.text_embedding(text_inputs)
text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
text_emb = text_emb + self.text_tags
mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
mel_emb = mel_emb + self.audio_tags
emb = torch.cat([text_emb,
self.separator.repeat(text_emb.shape[0],1,1),
mel_emb], dim=1)
enc = self.gpt(emb, text_emb.shape[1])
mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
gates = self.gate_head(mel_portion).squeeze(1)
mel_pred = self.mel_head(mel_portion)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
mel_emb = self.mel_embedding(mel_targets)
mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
emb = torch.cat([text_emb, mel_emb], dim=1)
enc = self.gpt(emb)
# Mask portions of output which we don't need to predict.
mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1])
mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1)
mel_pred.data.masked_fill_(mask, 0)
gates.data.masked_fill_(mask[:, 0, :], 1e3)
# Compute logits for text and mel heads
text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
text_logits = self.text_head(text_logits)
mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
mel_logits = self.mel_head(mel_logits)
if padded:
mel_pred = mel_pred[:, :, :-1]
gates = gates[:, :-1]
# Compute loss
loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')
# Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
pad_loss_reduction_factor = .01
loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)
#postnet_mel_pred = self.postnet(mel_pred)
#return mel_pred, postnet_mel_pred, gates
return mel_pred, gates
# Fix up mel_logits so it can go into a VAE decoder as well.
mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
mel_codes = mel_codes[:,1:]
mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0)
mel_codes = mel_codes[:,:-1]
extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD
mel_codes = mel_codes * extra_mask
def test_guide(self, mel_guide, amount=50):
mel_guide = mel_guide[:,:,:amount]
mel_emb = self.mel_encoder(mel_guide).permute(0,2,1)
mel_emb = mel_emb + self.audio_tags
return mel_emb
return loss_text.mean(), loss_mel.mean(), mel_codes
def inference(self, text_inputs, mel_guide):
MEL_HEAD_EXPANSION = 2
@ -138,12 +146,11 @@ def register_gpt_tts(opt_net, opt):
if __name__ == '__main__':
gpt = GptTts()
m, g = gpt(torch.randint(high=24, size=(2,60)),
torch.randn(2,80,747),
torch.tensor([600,747]))
print(m.shape)
#print(p.shape)
print(g.shape)
l1, l2, i = gpt(torch.randint(high=24, size=(2,60)),
torch.tensor([55,58]),
torch.randint(high=512, size=(2,310)),
torch.tensor([300,305]))
print(i.shape)
#o = gpt.infer(torch.randint(high=24, size=(2,60)))
#print(o.shape)

View File

@ -16,6 +16,8 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from utils.util import checkpoint, sequential_checkpoint
logger = logging.getLogger(__name__)
class GPTConfig:
@ -56,7 +58,7 @@ class CausalSelfAttention(nn.Module):
.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
def forward(self, x, text_block_size):
def forward(self, x):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
@ -66,12 +68,10 @@ class CausalSelfAttention(nn.Module):
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:,:,:T,:T].logical_or(
F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0,
float('-inf'))
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
@ -93,19 +93,22 @@ class Block(nn.Module):
nn.Dropout(config.resid_pdrop),
)
def forward(self, x, text_block_size):
x = x + self.attn(self.ln1(x), text_block_size)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, config):
def __init__(self, config, do_pos_emb=True):
super().__init__()
# input embedding stem
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
if do_pos_emb:
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
else:
self.pos_emb = None
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
@ -173,14 +176,14 @@ class GPT(nn.Module):
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
return optimizer
def forward(self, embeddings, text_block_sizes):
def forward(self, embeddings):
b, t, c = embeddings.size()
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
# forward the GPT model
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(embeddings + position_embeddings)
for block in self.blocks:
x = block(x, text_block_sizes)
if self.pos_emb is not None:
embeddings = embeddings + self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(embeddings)
x = sequential_checkpoint(self.blocks, 4, x)
return x

View File

@ -236,6 +236,16 @@ class VQVAE(nn.Module):
return quant_t, quant_b, diff_t + diff_b, id_t, id_b
def encode_only_quantized(self, input):
qt, qb, d, idt, idb = self.encode(input)
# Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
# [t0,b0,b1,t1,b1,b2,t2,b3,b4....]
b, s = idt.shape
idt = idt.view(b, s, 1)
idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
return ids
def decode(self, quant_t, quant_b):
upsample_t = self.upsample_t(quant_t)
quant = torch.cat([upsample_t, quant_b], 1)
@ -245,14 +255,25 @@ class VQVAE(nn.Module):
def decode_code(self, code_t, code_b):
quant_t = self.quantize_t.embed_code(code_t)
quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
quant_t = quant_t.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1))
quant_b = self.quantize_b.embed_code(code_b)
quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
quant_b = quant_b.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1))
dec = self.decode(quant_t, quant_b)
return dec
# Performs decode_code() with the outputs from encode_only_quantized.
def decode_code_joined(self, input):
b, s = input.shape
assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized.
s = s // 3
input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
t = input[:,0,:]
b = input[:,1:,:].reshape(b, 2*s)
return self.decode_code(t, b)
@register_model
def register_vqvae(opt_net, opt):
@ -272,5 +293,7 @@ def register_vqvae_audio(opt_net, opt):
if __name__ == '__main__':
model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
res=model(torch.randn(1,80,224))
#res=model(torch.randn(1,80,2048))
e = model.encode_only_quantized(torch.randn(1, 80, 2048))
model.decode_code_joined(e)
print(res[0].shape)

View File

@ -51,7 +51,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
want_metrics = False
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@ -81,7 +81,9 @@ class ConfigurableStep(Module):
norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
emb_modules = (nn.Embedding, nn.EmbeddingBag)
params_notweights = set()
param_names_notweights = set()
all_param_names = set()
param_map = {}
for mn, m in net.named_modules():
for k, v in m.named_parameters():
v.is_bias = k.endswith(".bias")
@ -89,8 +91,11 @@ class ConfigurableStep(Module):
v.is_norm = isinstance(m, norm_modules)
v.is_emb = isinstance(m, emb_modules)
fpn = '%s.%s' % (mn, k) if mn else k # full param name
all_param_names.add(fpn)
param_map[fpn] = v
if v.is_bias or v.is_norm or v.is_emb:
params_notweights.add(v)
param_names_notweights.add(fpn)
# Some models can specify some parameters to be in different groups.
param_group = "default"
@ -106,7 +111,8 @@ class ConfigurableStep(Module):
else:
if self.env['rank'] <= 0:
logger.warning('Params [{:s}] will not optimize.'.format(k))
params_weights = set(net.parameters()) ^ params_notweights
params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]
if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
@ -114,8 +120,8 @@ class ConfigurableStep(Module):
betas=(opt_config['beta1'], opt_config['beta2']))
elif self.step_opt['optimizer'] == 'adamw':
groups = [
{ 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
{ 'params': list(params_notweights), 'weight_decay': 0 }
{ 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
{ 'params': params_notweights, 'weight_decay': 0 }
]
opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
@ -190,6 +196,9 @@ class ConfigurableStep(Module):
if no_ddp_sync and hasattr(training_net, 'no_sync'):
with training_net.no_sync():
injected = inj(local_state)
elif opt_get(inj.opt, ['no_grad'], False):
with torch.no_grad():
injected = inj(local_state)
else:
injected = inj(local_state)
local_state.update(injected)