Get dalle-style TTS to "work"

James Betker 2021-08-03 21:08:27 -06:00
parent 2814307eee
commit 4c98b9703f
6 changed files with 127 additions and 85 deletions

@@ -15,78 +15,86 @@ from trainer.networks import register_model
 class GptTts(nn.Module):
+    NUMBER_SYMBOLS = len(symbols)+3
+    TEXT_START_TOKEN = NUMBER_SYMBOLS-3
+    TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2
+    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
+    MEL_DICTIONARY_SIZE = 512+3
+    MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
+    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
+    MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1
+
     def __init__(self):
         super().__init__()
-        number_symbols = len(symbols)
         model_dim = 512
         max_symbols_per_phrase = 200
-        max_mel_frames = 900
+        max_mel_frames = 900 * 3 // 8  # The VQVAE outputs 3/8 of the input mel as tokens.
         mel_dim=80

         self.model_dim = model_dim
         self.max_mel_frames = max_mel_frames
-        self.text_embedding = nn.Embedding(number_symbols, model_dim)
-        # Whenever we process MEL frames, we need to be careful to use casually masked convolutions to avoid adding bias
-        # into the model which we cannot provide in inference.
-        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
-                                         PixelUnshuffle1D(2),
-                                         ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
-                                         ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d))
+        self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim)
+        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
         # *_tags are additively applied to
-        self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
-        self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
-        self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
-        self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim))
-        self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
-        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
-                                       PixelShuffle1D(2),
-                                       ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
-                                       ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
-        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
-                                      PixelShuffle1D(2),
-                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
-                                      ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
-        #self.postnet = Postnet(munchify(hparams.create_hparams()))
+        self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim)
+        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
+        self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
+        self.final_norm = nn.LayerNorm(model_dim)
+        self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS)
+        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

-    def forward(self, text_inputs, mel_targets, output_lengths):
-        # Pad mel_targets to be a multiple of 2
-        padded = mel_targets.shape[-1] % 2 != 0
-        if padded:
-            mel_targets = F.pad(mel_targets, (0,1))
+    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
+        output_lengths = output_lengths * 3 // 8  # The data we are dealing with has been compressed by the vqvae.
+        # Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level.
+        batch_range = torch.arange(0, text_inputs.shape[0])
+        text_inputs = F.pad(text_inputs, (0,1))
+        text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
+        text_lengths = text_lengths + 1
+        mel_targets = F.pad(mel_targets, (0,1))
+        mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
+        output_lengths = output_lengths + 1
+        # Add the start tokens to the beginnings of the texts and mels.
+        text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN)
+        text_lengths = text_lengths + 1
+        mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN)
+        output_lengths = output_lengths + 1
+        # Add padding as well. This also should realistically be done at the dataloader level.
+        text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
+        text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN)
+        mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
+        mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN)

         text_emb = self.text_embedding(text_inputs)
-        text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
-        text_emb = text_emb + self.text_tags
-        mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
-        mel_emb = mel_emb + self.audio_tags
-        emb = torch.cat([text_emb,
-                         self.separator.repeat(text_emb.shape[0],1,1),
-                         mel_emb], dim=1)
-        enc = self.gpt(emb, text_emb.shape[1])
-        mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
-        gates = self.gate_head(mel_portion).squeeze(1)
-        mel_pred = self.mel_head(mel_portion)
-
-        # Mask portions of output which we don't need to predict.
-        mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1])
-        mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1)
-        mel_pred.data.masked_fill_(mask, 0)
-        gates.data.masked_fill_(mask[:, 0, :], 1e3)
-
-        if padded:
-            mel_pred = mel_pred[:, :, :-1]
-            gates = gates[:, :-1]
-
-        #postnet_mel_pred = self.postnet(mel_pred)
-        #return mel_pred, postnet_mel_pred, gates
-        return mel_pred, gates
-
-    def test_guide(self, mel_guide, amount=50):
-        mel_guide = mel_guide[:,:,:amount]
-        mel_emb = self.mel_encoder(mel_guide).permute(0,2,1)
-        mel_emb = mel_emb + self.audio_tags
-        return mel_emb
+        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        mel_emb = self.mel_embedding(mel_targets)
+        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
+        emb = torch.cat([text_emb, mel_emb], dim=1)
+        enc = self.gpt(emb)
+
+        # Compute logits for text and mel heads
+        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
+        text_logits = self.text_head(text_logits)
+        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
+        mel_logits = self.mel_head(mel_logits)
+
+        # Compute loss
+        loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
+        loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')
+        # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
+        pad_loss_reduction_factor = .01
+        loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
+        loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)
+
+        # Fix up mel_logits so it can go into a VAE decoder as well.
+        mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
+        mel_codes = mel_codes[:,1:]
+        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0)
+        mel_codes = mel_codes[:,:-1]
+        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3  # The VAE doesn't know about START/STOP/PAD
+        mel_codes = mel_codes * extra_mask
+
+        return loss_text.mean(), loss_mel.mean(), mel_codes

     def inference(self, text_inputs, mel_guide):
         MEL_HEAD_EXPANSION = 2
@@ -138,12 +146,11 @@ def register_gpt_tts(opt_net, opt):

 if __name__ == '__main__':
     gpt = GptTts()
-    m, g = gpt(torch.randint(high=24, size=(2,60)),
-               torch.randn(2,80,747),
-               torch.tensor([600,747]))
-    print(m.shape)
-    #print(p.shape)
-    print(g.shape)
+    l1, l2, i = gpt(torch.randint(high=24, size=(2,60)),
+                    torch.tensor([55,58]),
+                    torch.randint(high=512, size=(2,310)),
+                    torch.tensor([300,305]))
+    print(i.shape)

     #o = gpt.infer(torch.randint(high=24, size=(2,60)))
     #print(o.shape)
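
The rewritten forward() above frames both the text and mel token sequences with START/STOP/PAD markers before handing them to the GPT. A standalone sketch of that framing, with made-up token ids and without the repo's get_mask_from_lengths helper, looks roughly like this:

# Illustration only: hypothetical token ids, not the model's actual dictionary values.
import torch
import torch.nn.functional as F

START, STOP, PAD = 100, 101, 102
batch_range = torch.arange(0, 2)
text = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])   # 0 marks unused slots
lengths = torch.tensor([3, 2])

text = F.pad(text, (0, 1))                                      # make room for the stop token
text.index_put_((batch_range, lengths), torch.tensor([STOP]))   # drop STOP at each sequence's end
lengths = lengths + 1
text = F.pad(text, (1, 0), value=START)                          # prepend START
lengths = lengths + 1
mask = torch.arange(text.shape[1]).unsqueeze(0) >= lengths.unsqueeze(1)
text.masked_fill_(mask, PAD)                                     # everything past the length becomes PAD
print(text)  # [[START, 5, 6, 7, STOP, PAD], [START, 8, 9, STOP, PAD, PAD]]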

@@ -16,6 +16,8 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
+
+from utils.util import checkpoint, sequential_checkpoint

 logger = logging.getLogger(__name__)

 class GPTConfig:
@@ -56,7 +58,7 @@ class CausalSelfAttention(nn.Module):
                              .view(1, 1, config.block_size, config.block_size))
         self.n_head = config.n_head

-    def forward(self, x, text_block_size):
+    def forward(self, x):
         B, T, C = x.size()

         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -66,9 +68,7 @@ class CausalSelfAttention(nn.Module):

         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.mask[:,:,:T,:T].logical_or(
-            F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0,
-            float('-inf'))
+        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.attn_drop(att)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
@@ -93,19 +93,22 @@ class Block(nn.Module):
             nn.Dropout(config.resid_pdrop),
         )

-    def forward(self, x, text_block_size):
-        x = x + self.attn(self.ln1(x), text_block_size)
+    def forward(self, x):
+        x = x + self.attn(self.ln1(x))
         x = x + self.mlp(self.ln2(x))
         return x

 class GPT(nn.Module):
     """ the full GPT language model, with a context size of block_size """

-    def __init__(self, config):
+    def __init__(self, config, do_pos_emb=True):
         super().__init__()

         # input embedding stem
-        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+        if do_pos_emb:
+            self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+        else:
+            self.pos_emb = None
         self.drop = nn.Dropout(config.embd_pdrop)
         # transformer
         self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
@@ -173,14 +176,14 @@ class GPT(nn.Module):
         optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
         return optimizer

-    def forward(self, embeddings, text_block_sizes):
+    def forward(self, embeddings):
         b, t, c = embeddings.size()
         assert t <= self.block_size, "Cannot forward, model block size is exhausted."

         # forward the GPT model
-        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
-        x = self.drop(embeddings + position_embeddings)
-        for block in self.blocks:
-            x = block(x, text_block_sizes)
+        if self.pos_emb is not None:
+            embeddings = embeddings + self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+        x = self.drop(embeddings)
+        x = sequential_checkpoint(self.blocks, 4, x)
         return x
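
sequential_checkpoint is imported from utils.util and its implementation is not part of this diff. A plausible minimal stand-in, assuming it simply splits the nn.Sequential into a few segments and applies PyTorch's gradient checkpointing to each to trade compute for memory, would be:

# Hypothetical stand-in for utils.util.sequential_checkpoint (not the repo's actual code).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

def sequential_checkpoint(blocks: nn.Sequential, partitions: int, x: torch.Tensor) -> torch.Tensor:
    if not x.requires_grad:
        return blocks(x)  # checkpointing only pays off when gradients are needed
    return checkpoint_sequential(blocks, partitions, x)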

@@ -236,6 +236,16 @@ class VQVAE(nn.Module):

         return quant_t, quant_b, diff_t + diff_b, id_t, id_b

+    def encode_only_quantized(self, input):
+        qt, qb, d, idt, idb = self.encode(input)
+        # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
+        # [t0,b0,b1,t1,b1,b2,t2,b3,b4....]
+        b, s = idt.shape
+        idt = idt.view(b, s, 1)
+        idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
+        ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
+        return ids
+
     def decode(self, quant_t, quant_b):
         upsample_t = self.upsample_t(quant_t)
         quant = torch.cat([upsample_t, quant_b], 1)
@@ -245,14 +255,25 @@ class VQVAE(nn.Module):

     def decode_code(self, code_t, code_b):
         quant_t = self.quantize_t.embed_code(code_t)
-        quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+        quant_t = quant_t.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1))
         quant_b = self.quantize_b.embed_code(code_b)
-        quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
+        quant_b = quant_b.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1))
         dec = self.decode(quant_t, quant_b)
         return dec

+    # Performs decode_code() with the outputs from encode_only_quantized.
+    def decode_code_joined(self, input):
+        b, s = input.shape
+        assert s % 3 == 0  # If not, this tensor didn't come from encode_only_quantized.
+        s = s // 3
+        input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
+        t = input[:,0,:]
+        b = input[:,1:,:].reshape(b, 2*s)
+        return self.decode_code(t, b)
+
 @register_model
 def register_vqvae(opt_net, opt):
@@ -272,5 +293,7 @@ def register_vqvae_audio(opt_net, opt):

 if __name__ == '__main__':
     model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
-    res=model(torch.randn(1,80,224))
+    #res=model(torch.randn(1,80,2048))
+    e = model.encode_only_quantized(torch.randn(1, 80, 2048))
+    model.decode_code_joined(e)
     print(res[0].shape)
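
The new encode_only_quantized/decode_code_joined pair packs each top code together with two bottom codes into one flat sequence and unpacks it again. A tensor-only sketch of that round trip (toy shapes, no VQVAE required) is below; note that, as written, the two bottom codes grouped with each top code are drawn from the two halves of the flattened bottom sequence:

# Standalone illustration of the [top, bottom, bottom] interleave; shapes are toy values.
import torch

b_sz, s = 1, 4                                          # batch of 1, 4 "top" codes
idt = torch.arange(0, s).view(b_sz, s)                  # top codes:    [0, 1, 2, 3]
idb = torch.arange(10, 10 + 2 * s).view(b_sz, 2 * s)    # bottom codes: [10..17]

# Interleave exactly as encode_only_quantized does.
idt_r = idt.view(b_sz, s, 1)
idb_r = idb.reshape(b_sz, 2, s).permute(0, 2, 1).contiguous()
ids = torch.cat([idt_r, idb_r], dim=2).reshape(b_sz, s * 3)
print(ids)  # tensor([[ 0, 10, 14,  1, 11, 15,  2, 12, 16,  3, 13, 17]])

# Undo it exactly as decode_code_joined does.
undone = ids.reshape(b_sz, s, 3).permute(0, 2, 1).contiguous()
t_back = undone[:, 0, :]                        # -> [0, 1, 2, 3]
b_back = undone[:, 1:, :].reshape(b_sz, 2 * s)  # -> [10..17]
assert torch.equal(t_back, idt) and torch.equal(b_back, idb)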

@@ -51,7 +51,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt

@@ -300,7 +300,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

@@ -81,7 +81,9 @@ class ConfigurableStep(Module):
         norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
                         nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
         emb_modules = (nn.Embedding, nn.EmbeddingBag)
-        params_notweights = set()
+        param_names_notweights = set()
+        all_param_names = set()
+        param_map = {}
         for mn, m in net.named_modules():
             for k, v in m.named_parameters():
                 v.is_bias = k.endswith(".bias")
@@ -89,8 +91,11 @@ class ConfigurableStep(Module):
                 v.is_norm = isinstance(m, norm_modules)
                 v.is_emb = isinstance(m, emb_modules)
+                fpn = '%s.%s' % (mn, k) if mn else k  # full param name
+                all_param_names.add(fpn)
+                param_map[fpn] = v
                 if v.is_bias or v.is_norm or v.is_emb:
-                    params_notweights.add(v)
+                    param_names_notweights.add(fpn)

                 # Some models can specify some parameters to be in different groups.
                 param_group = "default"
@@ -106,7 +111,8 @@ class ConfigurableStep(Module):
                 else:
                     if self.env['rank'] <= 0:
                         logger.warning('Params [{:s}] will not optimize.'.format(k))
-        params_weights = set(net.parameters()) ^ params_notweights
+        params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
+        params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]

         if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
             opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],
@@ -114,8 +120,8 @@ class ConfigurableStep(Module):
                                    betas=(opt_config['beta1'], opt_config['beta2']))
         elif self.step_opt['optimizer'] == 'adamw':
             groups = [
-                { 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
-                { 'params': list(params_notweights), 'weight_decay': 0 }
+                { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
+                { 'params': params_notweights, 'weight_decay': 0 }
             ]
             opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                     weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),
@@ -190,6 +196,9 @@ class ConfigurableStep(Module):
                 if no_ddp_sync and hasattr(training_net, 'no_sync'):
                     with training_net.no_sync():
                         injected = inj(local_state)
+                elif opt_get(inj.opt, ['no_grad'], False):
+                    with torch.no_grad():
+                        injected = inj(local_state)
                 else:
                     injected = inj(local_state)
                 local_state.update(injected)
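
The optimizer changes above switch from a Python set of parameter tensors to name-keyed, sorted lists, so the AdamW weight-decay groups are built in a deterministic order across runs and ranks. A condensed, hypothetical version of that pattern (it omits the per-group options and most of the is_norm/is_emb bookkeeping of the real step) might look like:

# Illustrative sketch only; helper name and decay rules are assumptions, not the repo's code.
import torch
import torch.nn as nn

def build_adamw_groups(net: nn.Module, weight_decay: float = 1e-2):
    param_map, no_decay_names, all_names = {}, set(), set()
    for mn, m in net.named_modules():
        for k, v in m.named_parameters(recurse=False):
            fpn = '%s.%s' % (mn, k) if mn else k    # full, unique parameter name
            param_map[fpn] = v
            all_names.add(fpn)
            if k == 'bias' or isinstance(m, (nn.LayerNorm, nn.Embedding)):
                no_decay_names.add(fpn)
    # Sorting by name yields a deterministic parameter order, which a set of tensors would not.
    no_decay = [param_map[n] for n in sorted(no_decay_names)]
    decay = [param_map[n] for n in sorted(all_names - no_decay_names)]
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0}]

opt = torch.optim.AdamW(build_adamw_groups(nn.Linear(8, 8)), lr=1e-4)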