From 4c98b9703fe837379a079cc3b2f428683799f0b5 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 3 Aug 2021 21:08:27 -0600 Subject: [PATCH] Get dalle-style TTS to "work" --- codes/models/gpt_voice/gpt_tts.py | 129 ++++++++++++++------------ codes/models/gpt_voice/min_gpt.py | 31 ++++--- codes/models/vqvae/vqvae.py | 29 +++++- codes/scripts/audio/test_audio_gen.py | 2 +- codes/train.py | 2 +- codes/trainer/steps.py | 19 +++- 6 files changed, 127 insertions(+), 85 deletions(-) diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py index 27e9c1e2..ed06fd24 100644 --- a/codes/models/gpt_voice/gpt_tts.py +++ b/codes/models/gpt_voice/gpt_tts.py @@ -15,78 +15,86 @@ from trainer.networks import register_model class GptTts(nn.Module): + NUMBER_SYMBOLS = len(symbols)+3 + TEXT_START_TOKEN = NUMBER_SYMBOLS-3 + TEXT_STOP_TOKEN = NUMBER_SYMBOLS-2 + TEXT_PAD_TOKEN = NUMBER_SYMBOLS-1 + MEL_DICTIONARY_SIZE = 512+3 + MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3 + MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2 + MEL_PAD_TOKEN = MEL_DICTIONARY_SIZE-1 + def __init__(self): super().__init__() - number_symbols = len(symbols) model_dim = 512 max_symbols_per_phrase = 200 - max_mel_frames = 900 + max_mel_frames = 900 * 3 // 8 # The VQVAE outputs 3/8 of the input mel as tokens. mel_dim=80 self.model_dim = model_dim self.max_mel_frames = max_mel_frames - self.text_embedding = nn.Embedding(number_symbols, model_dim) - # Whenever we process MEL frames, we need to be careful to use casually masked convolutions to avoid adding bias - # into the model which we cannot provide in inference. - self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d), - PixelUnshuffle1D(2), - ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d), - ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d)) + self.text_embedding = nn.Embedding(self.NUMBER_SYMBOLS, model_dim) + self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim) # *_tags are additively applied to - self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0) - self.separator = nn.Parameter(torch.randn(1, 1, model_dim)) - self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0) - self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim)) - self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8)) + self.text_pos_embedding = nn.Embedding(max_symbols_per_phrase, model_dim) + self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim) + self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False) - self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d), - PixelShuffle1D(2), - ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d), - ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d)) - self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d), - PixelShuffle1D(2), - ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d), - ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d)) - #self.postnet = Postnet(munchify(hparams.create_hparams())) + self.final_norm = nn.LayerNorm(model_dim) + self.text_head = nn.Linear(model_dim, self.NUMBER_SYMBOLS) + self.mel_head = nn.Linear(model_dim, self.MEL_DICTIONARY_SIZE) - def forward(self, text_inputs, mel_targets, output_lengths): - # Pad mel_targets to be a multiple of 2 - padded = mel_targets.shape[-1] % 2 != 0 - if padded: - mel_targets = F.pad(mel_targets, (0,1)) + def forward(self, text_inputs, text_lengths, mel_targets, output_lengths): + output_lengths = output_lengths * 3 // 8 # The data we are dealing with has been compressed by the vqvae. + # Add the stop tokens to the end of the texts and mels. Theoretically this would be better done at the dataloader level. + batch_range = torch.arange(0, text_inputs.shape[0]) + text_inputs = F.pad(text_inputs, (0,1)) + text_inputs.index_put_((batch_range, text_lengths), torch.tensor([self.TEXT_STOP_TOKEN], dtype=torch.long, device=text_inputs.device)) + text_lengths = text_lengths + 1 + mel_targets = F.pad(mel_targets, (0,1)) + mel_targets.index_put_((batch_range, output_lengths), torch.tensor([self.MEL_STOP_TOKEN], dtype=torch.long, device=text_inputs.device)) + output_lengths = output_lengths + 1 + # Add the start tokens to the beginnings of the texts and mels. + text_inputs = F.pad(text_inputs, (1,0), value=self.TEXT_START_TOKEN) + text_lengths = text_lengths + 1 + mel_targets = F.pad(mel_targets, (1,0), value=self.MEL_START_TOKEN) + output_lengths = output_lengths + 1 + # Add padding as well. This also should realistically be done at the dataloader level. + text_pad_mask = ~get_mask_from_lengths(text_lengths, text_inputs.shape[1]) + text_inputs.data.masked_fill_(text_pad_mask, self.TEXT_PAD_TOKEN) + mel_pad_mask = ~get_mask_from_lengths(output_lengths, mel_targets.shape[1]) + mel_targets.data.masked_fill_(mel_pad_mask, self.MEL_PAD_TOKEN) text_emb = self.text_embedding(text_inputs) - text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1]) - text_emb = text_emb + self.text_tags - mel_emb = self.mel_encoder(mel_targets).permute(0,2,1) - mel_emb = mel_emb + self.audio_tags - emb = torch.cat([text_emb, - self.separator.repeat(text_emb.shape[0],1,1), - mel_emb], dim=1) - enc = self.gpt(emb, text_emb.shape[1]) - mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1) - gates = self.gate_head(mel_portion).squeeze(1) - mel_pred = self.mel_head(mel_portion) + text_emb = text_emb + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device)) + mel_emb = self.mel_embedding(mel_targets) + mel_emb = mel_emb + self.mel_pos_embedding(torch.arange(mel_targets.shape[1], device=mel_targets.device)) + emb = torch.cat([text_emb, mel_emb], dim=1) + enc = self.gpt(emb) - # Mask portions of output which we don't need to predict. - mask = ~get_mask_from_lengths(output_lengths, mel_pred.shape[-1]) - mask = mask.unsqueeze(1).repeat(1, mel_pred.shape[1], 1) - mel_pred.data.masked_fill_(mask, 0) - gates.data.masked_fill_(mask[:, 0, :], 1e3) + # Compute logits for text and mel heads + text_logits = self.final_norm(enc[:, :text_emb.shape[1]]) + text_logits = self.text_head(text_logits) + mel_logits = self.final_norm(enc[:, text_emb.shape[1]:]) + mel_logits = self.mel_head(mel_logits) - if padded: - mel_pred = mel_pred[:, :, :-1] - gates = gates[:, :-1] + # Compute loss + loss_text = F.cross_entropy(text_logits.permute(0,2,1)[:,:,1:], text_inputs[:,1:], reduction='none') + loss_mel = F.cross_entropy(mel_logits.permute(0,2,1)[:,:,1:], mel_targets[:,1:], reduction='none') + # Apply a reduction factor across MEL_PAD and TEXT_PAD tokens. + pad_loss_reduction_factor = .01 + loss_text = loss_text * torch.ones_like(loss_text).masked_fill_(text_pad_mask[:,1:], pad_loss_reduction_factor) + loss_mel = loss_mel * torch.ones_like(loss_mel).masked_fill_(mel_pad_mask[:,1:], pad_loss_reduction_factor) - #postnet_mel_pred = self.postnet(mel_pred) - #return mel_pred, postnet_mel_pred, gates - return mel_pred, gates + # Fix up mel_logits so it can go into a VAE decoder as well. + mel_codes = torch.argmax(F.softmax(mel_logits, dim=-1), dim=-1) + mel_codes = mel_codes[:,1:] + mel_codes = mel_codes * torch.ones_like(mel_codes).masked_fill_(mel_pad_mask[:,1:], 0) + mel_codes = mel_codes[:,:-1] + extra_mask = mel_codes < self.MEL_DICTIONARY_SIZE-3 # The VAE doesn't know about START/STOP/PAD + mel_codes = mel_codes * extra_mask - def test_guide(self, mel_guide, amount=50): - mel_guide = mel_guide[:,:,:amount] - mel_emb = self.mel_encoder(mel_guide).permute(0,2,1) - mel_emb = mel_emb + self.audio_tags - return mel_emb + return loss_text.mean(), loss_mel.mean(), mel_codes def inference(self, text_inputs, mel_guide): MEL_HEAD_EXPANSION = 2 @@ -138,12 +146,11 @@ def register_gpt_tts(opt_net, opt): if __name__ == '__main__': gpt = GptTts() - m, g = gpt(torch.randint(high=24, size=(2,60)), - torch.randn(2,80,747), - torch.tensor([600,747])) - print(m.shape) - #print(p.shape) - print(g.shape) + l1, l2, i = gpt(torch.randint(high=24, size=(2,60)), + torch.tensor([55,58]), + torch.randint(high=512, size=(2,310)), + torch.tensor([300,305])) + print(i.shape) #o = gpt.infer(torch.randint(high=24, size=(2,60))) #print(o.shape) diff --git a/codes/models/gpt_voice/min_gpt.py b/codes/models/gpt_voice/min_gpt.py index 00e48b6e..209cc15e 100644 --- a/codes/models/gpt_voice/min_gpt.py +++ b/codes/models/gpt_voice/min_gpt.py @@ -16,6 +16,8 @@ import torch import torch.nn as nn from torch.nn import functional as F +from utils.util import checkpoint, sequential_checkpoint + logger = logging.getLogger(__name__) class GPTConfig: @@ -56,7 +58,7 @@ class CausalSelfAttention(nn.Module): .view(1, 1, config.block_size, config.block_size)) self.n_head = config.n_head - def forward(self, x, text_block_size): + def forward(self, x): B, T, C = x.size() # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -66,12 +68,10 @@ class CausalSelfAttention(nn.Module): # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.mask[:,:,:T,:T].logical_or( - F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0, - float('-inf')) + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_drop(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection @@ -93,19 +93,22 @@ class Block(nn.Module): nn.Dropout(config.resid_pdrop), ) - def forward(self, x, text_block_size): - x = x + self.attn(self.ln1(x), text_block_size) + def forward(self, x): + x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class GPT(nn.Module): """ the full GPT language model, with a context size of block_size """ - def __init__(self, config): + def __init__(self, config, do_pos_emb=True): super().__init__() # input embedding stem - self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + if do_pos_emb: + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + else: + self.pos_emb = None self.drop = nn.Dropout(config.embd_pdrop) # transformer self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) @@ -173,14 +176,14 @@ class GPT(nn.Module): optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) return optimizer - def forward(self, embeddings, text_block_sizes): + def forward(self, embeddings): b, t, c = embeddings.size() assert t <= self.block_size, "Cannot forward, model block size is exhausted." # forward the GPT model - position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector - x = self.drop(embeddings + position_embeddings) - for block in self.blocks: - x = block(x, text_block_sizes) + if self.pos_emb is not None: + embeddings = embeddings + self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(embeddings) + x = sequential_checkpoint(self.blocks, 4, x) return x \ No newline at end of file diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py index 012864f3..75f91b8b 100644 --- a/codes/models/vqvae/vqvae.py +++ b/codes/models/vqvae/vqvae.py @@ -236,6 +236,16 @@ class VQVAE(nn.Module): return quant_t, quant_b, diff_t + diff_b, id_t, id_b + def encode_only_quantized(self, input): + qt, qb, d, idt, idb = self.encode(input) + # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like + # [t0,b0,b1,t1,b1,b2,t2,b3,b4....] + b, s = idt.shape + idt = idt.view(b, s, 1) + idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous() + ids = torch.cat([idt, idb], dim=2).reshape(b, s*3) + return ids + def decode(self, quant_t, quant_b): upsample_t = self.upsample_t(quant_t) quant = torch.cat([upsample_t, quant_b], 1) @@ -245,14 +255,25 @@ class VQVAE(nn.Module): def decode_code(self, code_t, code_b): quant_t = self.quantize_t.embed_code(code_t) - quant_t = quant_t.permute((0,3,1,2) if len(input) == 4 else (0,2,1)) + quant_t = quant_t.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1)) quant_b = self.quantize_b.embed_code(code_b) - quant_b = quant_b.permute((0,3,1,2) if len(input) == 4 else (0,2,1)) + quant_b = quant_b.permute((0,3,1,2) if len(code_t.shape) == 4 else (0,2,1)) dec = self.decode(quant_t, quant_b) return dec + # Performs decode_code() with the outputs from encode_only_quantized. + def decode_code_joined(self, input): + b, s = input.shape + assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized. + s = s // 3 + + input = input.reshape(b, s, 3).permute(0,2,1).contiguous() + t = input[:,0,:] + b = input[:,1:,:].reshape(b, 2*s) + return self.decode_code(t, b) + @register_model def register_vqvae(opt_net, opt): @@ -272,5 +293,7 @@ def register_vqvae_audio(opt_net, opt): if __name__ == '__main__': model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d) - res=model(torch.randn(1,80,224)) + #res=model(torch.randn(1,80,2048)) + e = model.encode_only_quantized(torch.randn(1, 80, 2048)) + model.decode_code_joined(e) print(res[0].shape) \ No newline at end of file diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py index fd3c97f8..ce441c11 100644 --- a/codes/scripts/audio/test_audio_gen.py +++ b/codes/scripts/audio/test_audio_gen.py @@ -51,7 +51,7 @@ if __name__ == "__main__": torch.backends.cudnn.benchmark = True want_metrics = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) utils.util.loaded_options = opt diff --git a/codes/train.py b/codes/train.py index 35e6c88f..da87349b 100644 --- a/codes/train.py +++ b/codes/train.py @@ -300,7 +300,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_audio_lj.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 2d82f5fa..90f5a063 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -81,7 +81,9 @@ class ConfigurableStep(Module): norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d, nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm) emb_modules = (nn.Embedding, nn.EmbeddingBag) - params_notweights = set() + param_names_notweights = set() + all_param_names = set() + param_map = {} for mn, m in net.named_modules(): for k, v in m.named_parameters(): v.is_bias = k.endswith(".bias") @@ -89,8 +91,11 @@ class ConfigurableStep(Module): v.is_norm = isinstance(m, norm_modules) v.is_emb = isinstance(m, emb_modules) + fpn = '%s.%s' % (mn, k) if mn else k # full param name + all_param_names.add(fpn) + param_map[fpn] = v if v.is_bias or v.is_norm or v.is_emb: - params_notweights.add(v) + param_names_notweights.add(fpn) # Some models can specify some parameters to be in different groups. param_group = "default" @@ -106,7 +111,8 @@ class ConfigurableStep(Module): else: if self.env['rank'] <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) - params_weights = set(net.parameters()) ^ params_notweights + params_notweights = [param_map[k] for k in sorted(list(param_names_notweights))] + params_weights = [param_map[k] for k in sorted(list(all_param_names ^ param_names_notweights))] if 'optimizer' not in self.step_opt.keys() or self.step_opt['optimizer'] == 'adam': opt = torch.optim.Adam(list(optim_params.values()), lr=opt_config['lr'], @@ -114,8 +120,8 @@ class ConfigurableStep(Module): betas=(opt_config['beta1'], opt_config['beta2'])) elif self.step_opt['optimizer'] == 'adamw': groups = [ - { 'params': list(params_weights), 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) }, - { 'params': list(params_notweights), 'weight_decay': 0 } + { 'params': params_weights, 'weight_decay': opt_get(opt_config, ['weight_decay'], 0) }, + { 'params': params_notweights, 'weight_decay': 0 } ] opt = torch.optim.AdamW(groups, lr=opt_config['lr'], weight_decay=opt_get(opt_config, ['weight_decay'], 1e-2), @@ -190,6 +196,9 @@ class ConfigurableStep(Module): if no_ddp_sync and hasattr(training_net, 'no_sync'): with training_net.no_sync(): injected = inj(local_state) + elif opt_get(inj.opt, ['no_grad'], False): + with torch.no_grad(): + injected = inj(local_state) else: injected = inj(local_state) local_state.update(injected)