Get dalle-style TTS to "work"
parent 2814307eee
commit 4c98b9703f
@@ -15,78 +15,86 @@ from trainer.networks import register_model


class GptTts(nn.Module):
    NUMBER_SYMBOLS = len(symbols)+3
    TEXT_START_TOKEN = NUMBER_SYMBOLS-3
    TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2
    TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1
    MEL_DICTIONARY_SIZE = 512+3
    MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
    MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
    MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1

    def __init__(self):
        super().__init__()
        number_symbols = len(symbols)
        model_dim = 512
        max_symbols_per_phrase = 200
        max_mel_frames = 900
        max_mel_frames = 900 * 3 // 8  # The VQVAE outputs 3/8 of the input mel as tokens.
        mel_dim = 80

        self.model_dim = model_dim
        self.max_mel_frames = max_mel_frames
        self.text_embedding = nn.Embedding(number_symbols, model_dim)
        # Whenever we process MEL frames, we need to be careful to use causally masked convolutions to avoid adding bias
        # into the model which we cannot provide in inference.
        self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
                                         PixelUnshuffle1D(2),
                                         ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
                                         ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d))
        self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim)
        self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
        # *_tags are additively applied to the corresponding embeddings.
        self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
        self.separator = nn.Parameter(torch.randn(1, 1, model_dim))
        self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0)
        self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim))
        self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8))
        self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim)
        self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
        self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)

        self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
                                       PixelShuffle1D(2),
                                       ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
                                       ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
        self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d),
                                      PixelShuffle1D(2),
                                      ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d),
                                      ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d))
        #self.postnet = Postnet(munchify(hparams.create_hparams()))
        self.final_norm = nn.LayerNorm(model_dim)
        self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS)
        self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE)

    def forward(self, text_inputs, mel_targets, output_lengths):
        # Pad mel_targets to be a multiple of 2
        padded = mel_targets.shape[-1] % 2 != 0
        if padded:
            mel_targets = F.pad(mel_targets, (0,1))
    def forward(self, text_inputs, text_lengths, mel_targets, output_lengths):
        output_lengths = output_lengths * 3 // 8  # The data we are dealing with has been compressed by the vqvae.
        # Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level.
        batch_range = torch.arange(0, text_inputs.shape[0])
        text_inputs = F.pad(text_inputs, (0,1))
        text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
        text_lengths = text_lengths + 1
        mel_targets = F.pad(mel_targets, (0,1))
        mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device))
        output_lengths = output_lengths + 1
        # Add the start tokens to the beginnings of the texts and mels.
        text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN)
        text_lengths = text_lengths + 1
        mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN)
        output_lengths = output_lengths + 1
        # Add padding as well. This also should realistically be done at the dataloader level.
        text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1])
        text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN)
        mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1])
        mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN)

        text_emb = self.text_embedding(text_inputs)
        text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1])
        text_emb = text_emb + self.text_tags
        mel_emb = self.mel_encoder(mel_targets).permute(0,2,1)
        mel_emb = mel_emb + self.audio_tags
        emb = torch.cat([text_emb,
                         self.separator.repeat(text_emb.shape[0],1,1),
                         mel_emb], dim=1)
        enc = self.gpt(emb, text_emb.shape[1])
        mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1)
        gates = self.gate_head(mel_portion).squeeze(1)
        mel_pred = self.mel_head(mel_portion)
        text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
        mel_emb = self.mel_embedding(mel_targets)
        mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device))
        emb = torch.cat([text_emb, mel_emb], dim=1)
        enc = self.gpt(emb)

        # Mask portions of output which we don't need to predict.
        mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1])
        mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1)
        mel_pred.data.masked_fill_(mask, 0)
        gates.data.masked_fill_(mask[:, 0, :], 1e3)
        # Compute logits for text and mel heads
        text_logits = self.final_norm(enc[:, :text_emb.shape[1]])
        text_logits = self.text_head(text_logits)
        mel_logits = self.final_norm(enc[:, text_emb.shape[1]:])
        mel_logits = self.mel_head(mel_logits)

        if padded:
            mel_pred = mel_pred[:, :, :-1]
            gates = gates[:, :-1]
        # Compute loss
        loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none')
        loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none')
        # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens.
        pad_loss_reduction_factor = .01
        loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor)
        loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor)

        #postnet_mel_pred = self.postnet(mel_pred)
        #return mel_pred, postnet_mel_pred, gates
        return mel_pred, gates
        # Fix up mel_logits so it can go into a VAE decoder as well.
        mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1)
        mel_codes = mel_codes[:,1:]
        mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0)
        mel_codes = mel_codes[:,:-1]
        extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3  # The VAE doesn't know about START/STOP/PAD
        mel_codes = mel_codes * extra_mask

    def test_guide(self, mel_guide, amount=50):
        mel_guide = mel_guide[:,:,:amount]
        mel_emb = self.mel_encoder(mel_guide).permute(0,2,1)
        mel_emb = mel_emb + self.audio_tags
        return mel_emb
        return loss_text.mean(), loss_mel.mean(), mel_codes

    def inference(self, text_inputs, mel_guide):
        MEL_HEAD_EXPANSION = 2

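Aside: the start/stop/pad bookkeeping in the new forward() is easy to get wrong, so here is a minimal standalone sketch of the same scheme on a toy batch (toy token IDs and sizes; plain indexing stands in for index_put_ and get_mask_from_lengths):

import torch
import torch.nn.functional as F

# Append a STOP token at each sequence's true length, prepend a START token,
# then overwrite everything past the new length with PAD.
TEXT_START_TOKEN, TEXT_STOP_TOKEN, TEXT_PAD_TOKEN = 100, 101, 102

text_inputs = torch.tensor([[5, 6, 7, 0], [8, 9, 0, 0]])   # (B, T), zero-padded
text_lengths = torch.tensor([3, 2])

batch_range = torch.arange(0, text_inputs.shape[0])
text_inputs = F.pad(text_inputs, (0, 1))                   # make room for the stop token
text_inputs[batch_range, text_lengths] = TEXT_STOP_TOKEN
text_lengths = text_lengths + 1
text_inputs = F.pad(text_inputs, (1, 0), value=TEXT_START_TOKEN)
text_lengths = text_lengths + 1

pad_mask = torch.arange(text_inputs.shape[1])[None, :] >= text_lengths[:, None]
text_inputs.masked_fill_(pad_mask, TEXT_PAD_TOKEN)
print(text_inputs)
# tensor([[100,   5,   6,   7, 101, 102],
#         [100,   8,   9, 101, 102, 102]])
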
@@ -138,12 +146,11 @@ def register_gpt_tts(opt_net, opt):


if __name__ == '__main__':
    gpt = GptTts()
    m, g = gpt(torch.randint(high=24, size=(2,60)),
               torch.randn(2,80,747),
               torch.tensor([600,747]))
    print(m.shape)
    #print(p.shape)
    print(g.shape)
    l1, l2, i = gpt(torch.randint(high=24, size=(2,60)),
                    torch.tensor([55,58]),
                    torch.randint(high=512, size=(2,310)),
                    torch.tensor([300,305]))
    print(i.shape)

    #o = gpt.infer(torch.randint(high=24, size=(2,60)))
    #print(o.shape)

@@ -16,6 +16,8 @@ import torch

import torch.nn as nn
from torch.nn import functional as F

from utils.util import checkpoint, sequential_checkpoint

logger = logging.getLogger(__name__)

class GPTConfig:

@@ -56,7 +58,7 @@ class CausalSelfAttention(nn.Module):

                             .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, text_block_size):
    def forward(self, x):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim

@@ -66,12 +68,10 @@

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T].logical_or(
            F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0,
            float('-inf'))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection

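For context, the change above drops the extra text-block unmasking and keeps only the standard causal mask. A toy illustration of that mask (illustrative shapes only, not the repo's code):

import torch
import torch.nn.functional as F

B, n_head, T, head_size = 1, 2, 5, 8
q = torch.randn(B, n_head, T, head_size)
k = torch.randn(B, n_head, T, head_size)

# Lower-triangular mask: position t may only attend to positions <= t.
mask = torch.tril(torch.ones(1, 1, T, T))
att = (q @ k.transpose(-2, -1)) * (1.0 / head_size ** 0.5)
att = att.masked_fill(mask[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
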
@@ -93,19 +93,22 @@ class Block(nn.Module):

            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x, text_block_size):
        x = x + self.attn(self.ln1(x), text_block_size)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
    def __init__(self, config, do_pos_emb=True):
        super().__init__()

        # input embedding stem
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        if do_pos_emb:
            self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        else:
            self.pos_emb = None
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])

@@ -173,14 +176,14 @@ class GPT(nn.Module):

        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, embeddings, text_block_sizes):
    def forward(self, embeddings):
        b, t, c = embeddings.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(embeddings + position_embeddings)
        for block in self.blocks:
            x = block(x, text_block_sizes)
        if self.pos_emb is not None:
            embeddings = embeddings + self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(embeddings)
        x = sequential_checkpoint(self.blocks, 4, x)

        return x

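The new forward() runs the transformer blocks through sequential_checkpoint from utils.util. As a rough sketch of the idea (an assumption about what such a helper does, not the repo's actual implementation), torch's built-in checkpoint_sequential behaves similarly: run the stack in a few segments and recompute activations during backward instead of storing them:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

blocks = nn.Sequential(*[nn.Linear(16, 16) for _ in range(8)])
x = torch.randn(2, 16, requires_grad=True)
y = checkpoint_sequential(blocks, 4, x)  # 4 checkpointed segments of 2 blocks each
y.sum().backward()
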
@@ -236,6 +236,16 @@ class VQVAE(nn.Module):

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def encode_only_quantized(self, input):
        qt, qb, d, idt, idb = self.encode(input)
        # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
        # [t0,b0,b1,t1,b1,b2,t2,b3,b4....]
        b, s = idt.shape
        idt = idt.view(b, s, 1)
        idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
        ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
        return ids

    def decode(self, quant_t, quant_b):
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)

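A toy round trip of the interleaving that encode_only_quantized performs and decode_code_joined (below) undoes, with made-up code values and no VQVAE involved. Note that because idb is split into halves by reshape(b, 2, s), each group of three comes out as [top_i, bottom_i, bottom_{s+i}] rather than strictly consecutive bottom codes:

import torch

b, s = 1, 4
idt = torch.arange(0, s).view(b, s)                  # "top" codes:    0..3
idb = torch.arange(10, 10 + 2 * s).view(b, 2 * s)    # "bottom" codes: 10..17

# encode_only_quantized-style interleave
ids = torch.cat([idt.view(b, s, 1),
                 idb.reshape(b, 2, s).permute(0, 2, 1).contiguous()], dim=2).reshape(b, s * 3)
print(ids)  # tensor([[ 0, 10, 14,  1, 11, 15,  2, 12, 16,  3, 13, 17]])

# decode_code_joined-style split recovers the original tensors
joined = ids.reshape(b, s, 3).permute(0, 2, 1).contiguous()
t_rec = joined[:, 0, :]
b_rec = joined[:, 1:, :].reshape(b, 2 * s)
assert torch.equal(t_rec, idt) and torch.equal(b_rec, idb)
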
@@ -245,14 +255,25 @@

    def decode_code(self, code_t, code_b):
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
        quant_t = quant_t.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1))
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1))
        quant_b = quant_b.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1))

        dec = self.decode(quant_t, quant_b)

        return dec

    # Performs decode_code() with the outputs from encode_only_quantized.
    def decode_code_joined(self, input):
        b, s = input.shape
        assert s % 3 == 0  # If not, this tensor didn't come from encode_only_quantized.
        s = s // 3

        input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
        t = input[:,0,:]
        b = input[:,1:,:].reshape(b, 2*s)
        return self.decode_code(t, b)


@register_model
def register_vqvae(opt_net, opt):

@@ -272,5 +293,7 @@ def register_vqvae_audio(opt_net, opt):

if __name__ == '__main__':
    model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
    res=model(torch.randn(1,80,224))
    #res=model(torch.randn(1,80,2048))
    e = model.encode_only_quantized(torch.randn(1, 80, 2048))
    model.decode_code_joined(e)
    print(res[0].shape)

@@ -51,7 +51,7 @@ if __name__ == "__main__":

    torch.backends.cudnn.benchmark = True
    want_metrics = False
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    opt = option.dict_to_nonedict(opt)
    utils.util.loaded_options = opt

@@ -300,7 +300,7 @@ class Trainer:

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml')
    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

@@ -81,7 +81,9 @@ class ConfigurableStep(Module):

        norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
                        nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
        emb_modules = (nn.Embedding, nn.EmbeddingBag)
        params_notweights = set()
        param_names_notweights = set()
        all_param_names = set()
        param_map = {}
        for mn, m in net.named_modules():
            for k, v in m.named_parameters():
                v.is_bias = k.endswith(".bias")

@@ -89,8 +91,11 @@

                v.is_norm = isinstance(m, norm_modules)
                v.is_emb = isinstance(m, emb_modules)

                fpn = '%s.%s' % (mn, k) if mn else k  # full param name
                all_param_names.add(fpn)
                param_map[fpn] = v
                if v.is_bias or v.is_norm or v.is_emb:
                    params_notweights.add(v)
                    param_names_notweights.add(fpn)

                # Some models can specify some parameters to be in different groups.
                param_group = "default"

@@ -106,7 +111,8 @@

                else:
                    if self.env['rank'] <= 0:
                        logger.warning('Params [{:s}] will not optimize.'.format(k))
        params_weights = set(net.parameters()) ^ params_notweights
        params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))]
        params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))]

        if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam':
            opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'],

@@ -114,8 +120,8 @@

                                   betas=(opt_config['beta1'], opt_config['beta2']))
        elif self.step_opt['optimizer'] == 'adamw':
            groups = [
                { 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
                { 'params': list(params_notweights), 'weight_decay': 0 }
                { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) },
                { 'params': params_notweights, 'weight_decay': 0 }
            ]
            opt = torch.optim.AdamW(groups, lr=opt_config['lr'],
                                    weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2),

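The optimizer change above splits parameters into a weight-decayed group and a non-decayed group (biases, norm, and embedding weights). A minimal sketch of that pattern on a toy module (the toy model and threshold names are assumptions; the real step derives the split from the is_bias/is_norm/is_emb flags set earlier):

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

decay, no_decay = [], []
for module in net.modules():
    for name, param in module.named_parameters(recurse=False):
        if name.endswith("bias") or isinstance(module, (nn.LayerNorm, nn.Embedding)):
            no_decay.append(param)   # biases / norm / embedding params: no decay
        else:
            decay.append(param)      # ordinary weights: regularized

opt = torch.optim.AdamW([
    {'params': decay, 'weight_decay': 1e-2},
    {'params': no_decay, 'weight_decay': 0},
], lr=1e-4)
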
@@ -190,6 +196,9 @@

        if no_ddp_sync and hasattr(training_net, 'no_sync'):
            with training_net.no_sync():
                injected = inj(local_state)
        elif opt_get(inj.opt, ['no_grad'], False):
            with torch.no_grad():
                injected = inj(local_state)
        else:
            injected = inj(local_state)
        local_state.update(injected)