From 0c9e75bc69af25c61b6b2afa53866d2a3f35d788 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 31 Jul 2021 15:57:57 -0600 Subject: [PATCH] Improvements to GptTts --- codes/models/gpt_voice/gpt_tts.py | 58 +++++++++++------------ codes/models/gpt_voice/min_gpt.py | 17 ++++--- codes/models/gpt_voice/pixelshuffle_1d.py | 49 +++++++++++++++++++ codes/scripts/audio/test_audio_gen.py | 11 +++++ codes/trainer/optimizers/sgd.py | 2 +- 5 files changed, 100 insertions(+), 37 deletions(-) create mode 100644 codes/models/gpt_voice/pixelshuffle_1d.py diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py index 9b680fc0..27e9c1e2 100644 --- a/codes/models/gpt_voice/gpt_tts.py +++ b/codes/models/gpt_voice/gpt_tts.py @@ -1,27 +1,19 @@ import torch import torch.nn as nn import torch.nn.functional as F +from munch import munchify from tqdm import tqdm from models.arch_util import ConvGnSilu +from models.gpt_voice.pixelshuffle_1d import PixelUnshuffle1D, PixelShuffle1D +from models.tacotron2 import hparams from models.tacotron2.taco_utils import get_mask_from_lengths +from models.tacotron2.tacotron2 import Postnet from models.tacotron2.text import symbols from models.gpt_voice.min_gpt import GPT, GPTConfig from trainer.networks import register_model -# A Conv1d that masks out kernel elements ahead of the current location. -class CausalConv1d(nn.Conv1d): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.kernel_mask = torch.ones_like(self.weight) - self.kernel_mask[:, :, -(self.kernel_size[0]//2):] = 0 - - def forward(self, input): - self.kernel_mask = self.kernel_mask.to(input.device) - return self._conv_forward(input, self.weight * self.kernel_mask, self.bias) - - class GptTts(nn.Module): def __init__(self): super().__init__() @@ -36,24 +28,26 @@ class GptTts(nn.Module): self.text_embedding = nn.Embedding(number_symbols, model_dim) # Whenever we process MEL frames, we need to be careful to use casually masked convolutions to avoid adding bias # into the model which we cannot provide in inference. 
- self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d), - ConvGnSilu(model_dim//2, model_dim, kernel_size=5, stride=2, convnd=CausalConv1d)) + self.mel_encoder = nn.Sequential(ConvGnSilu(mel_dim, model_dim//2, kernel_size=1, convnd=nn.Conv1d), + PixelUnshuffle1D(2), + ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d), + ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d)) # *_tags are additively applied to self.text_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0) self.separator = nn.Parameter(torch.randn(1, 1, model_dim)) self.audio_tags = nn.Parameter(torch.randn(1, 1, model_dim)/256.0) + self.text_preprocess_xformer = GPT(GPTConfig(max_symbols_per_phrase, n_layer=2, n_head=2, n_embd=model_dim)) self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames//2, n_embd=model_dim, n_head=8)) - self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d), - nn.Upsample(scale_factor=2, mode='nearest'), - ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d), - # No need for causal convolutions when kernel_size=1 - nn.Conv1d(model_dim//2, 1, kernel_size=1)) - self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=5, convnd=CausalConv1d), - nn.Upsample(scale_factor=2, mode='nearest'), - ConvGnSilu(model_dim, model_dim//2, kernel_size=5, convnd=CausalConv1d), - ConvGnSilu(model_dim//2, model_dim//2, kernel_size=5, convnd=CausalConv1d), - ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, activation=False, norm=False, convnd=nn.Conv1d)) + self.gate_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d), + PixelShuffle1D(2), + ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d), + ConvGnSilu(model_dim//2, 1, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d)) + self.mel_head = nn.Sequential(ConvGnSilu(model_dim, model_dim, kernel_size=1, convnd=nn.Conv1d), + PixelShuffle1D(2), + ConvGnSilu(model_dim//2, model_dim//2, kernel_size=1, convnd=nn.Conv1d), + ConvGnSilu(model_dim//2, mel_dim, kernel_size=1, norm=False, activation=False, convnd=nn.Conv1d)) + #self.postnet = Postnet(munchify(hparams.create_hparams())) def forward(self, text_inputs, mel_targets, output_lengths): # Pad mel_targets to be a multiple of 2 @@ -62,13 +56,14 @@ class GptTts(nn.Module): mel_targets = F.pad(mel_targets, (0,1)) text_emb = self.text_embedding(text_inputs) + text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1]) text_emb = text_emb + self.text_tags mel_emb = self.mel_encoder(mel_targets).permute(0,2,1) mel_emb = mel_emb + self.audio_tags emb = torch.cat([text_emb, self.separator.repeat(text_emb.shape[0],1,1), mel_emb], dim=1) - enc = self.gpt(emb) + enc = self.gpt(emb, text_emb.shape[1]) mel_portion = enc[:, text_emb.shape[1]+1:].permute(0,2,1) gates = self.gate_head(mel_portion).squeeze(1) mel_pred = self.mel_head(mel_portion) @@ -82,6 +77,9 @@ class GptTts(nn.Module): if padded: mel_pred = mel_pred[:, :, :-1] gates = gates[:, :-1] + + #postnet_mel_pred = self.postnet(mel_pred) + #return mel_pred, postnet_mel_pred, gates return mel_pred, gates def test_guide(self, mel_guide, amount=50): @@ -95,15 +93,16 @@ class GptTts(nn.Module): GATE_THRESHOLD = .95 text_emb = self.text_embedding(text_inputs) + text_emb = self.text_preprocess_xformer(text_emb, text_emb.shape[1]) text_emb = text_emb + self.text_tags b,s,c = text_emb.shape emb = torch.cat([text_emb, - 
self.separator.repeat(text_emb.shape[0],1,1)], dim=1) + self.separator.repeat(text_emb.shape[0],1,1),], dim=1) #self.test_guide(mel_guide)], dim=1) completed = torch.zeros((b,), device=text_inputs.device, dtype=torch.bool) output = None for i in tqdm(range(self.max_mel_frames)): - enc = self.gpt(emb) + enc = self.gpt(emb, text_emb.shape[1]) inferred = enc[:,s:,:].permute(0,2,1) # Create output frames. inferred_mel_frame = self.mel_head(inferred)[:,:,-MEL_HEAD_EXPANSION:] @@ -143,9 +142,10 @@ if __name__ == '__main__': torch.randn(2,80,747), torch.tensor([600,747])) print(m.shape) + #print(p.shape) print(g.shape) - o = gpt.infer(torch.randint(high=24, size=(2,60))) - print(o.shape) + #o = gpt.infer(torch.randint(high=24, size=(2,60))) + #print(o.shape) diff --git a/codes/models/gpt_voice/min_gpt.py b/codes/models/gpt_voice/min_gpt.py index 19a5189b..00e48b6e 100644 --- a/codes/models/gpt_voice/min_gpt.py +++ b/codes/models/gpt_voice/min_gpt.py @@ -56,7 +56,7 @@ class CausalSelfAttention(nn.Module): .view(1, 1, config.block_size, config.block_size)) self.n_head = config.n_head - def forward(self, x, layer_past=None): + def forward(self, x, text_block_size): B, T, C = x.size() # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -66,10 +66,12 @@ class CausalSelfAttention(nn.Module): # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.mask[:,:,:T,:T].logical_or( + F.pad(torch.ones((B,self.n_head,text_block_size,text_block_size), device=x.device), (0, T-text_block_size, 0, T-text_block_size))) == 0, + float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_drop(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection @@ -91,8 +93,8 @@ class Block(nn.Module): nn.Dropout(config.resid_pdrop), ) - def forward(self, x): - x = x + self.attn(self.ln1(x)) + def forward(self, x, text_block_size): + x = x + self.attn(self.ln1(x), text_block_size) x = x + self.mlp(self.ln2(x)) return x @@ -171,13 +173,14 @@ class GPT(nn.Module): optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) return optimizer - def forward(self, embeddings): + def forward(self, embeddings, text_block_sizes): b, t, c = embeddings.size() assert t <= self.block_size, "Cannot forward, model block size is exhausted." # forward the GPT model position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector x = self.drop(embeddings + position_embeddings) - x = self.blocks(x) + for block in self.blocks: + x = block(x, text_block_sizes) return x \ No newline at end of file diff --git a/codes/models/gpt_voice/pixelshuffle_1d.py b/codes/models/gpt_voice/pixelshuffle_1d.py new file mode 100644 index 00000000..4ff48904 --- /dev/null +++ b/codes/models/gpt_voice/pixelshuffle_1d.py @@ -0,0 +1,49 @@ +import torch + +# "long" and "short" denote longer and shorter samples +class PixelShuffle1D(torch.nn.Module): + """ + 1D pixel shuffler. 
https://arxiv.org/pdf/1609.05158.pdf + Upscales sample length, downscales channel length + "short" is input, "long" is output + """ + def __init__(self, upscale_factor): + super(PixelShuffle1D, self).__init__() + self.upscale_factor = upscale_factor + + def forward(self, x): + batch_size = x.shape[0] + short_channel_len = x.shape[1] + short_width = x.shape[2] + + long_channel_len = short_channel_len // self.upscale_factor + long_width = self.upscale_factor * short_width + + x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width]) + x = x.permute(0, 2, 3, 1).contiguous() + x = x.view(batch_size, long_channel_len, long_width) + + return x + +class PixelUnshuffle1D(torch.nn.Module): + """ + Inverse of 1D pixel shuffler + Upscales channel length, downscales sample length + "long" is input, "short" is output + """ + def __init__(self, downscale_factor): + super(PixelUnshuffle1D, self).__init__() + self.downscale_factor = downscale_factor + + def forward(self, x): + batch_size = x.shape[0] + long_channel_len = x.shape[1] + long_width = x.shape[2] + + short_channel_len = long_channel_len * self.downscale_factor + short_width = long_width // self.downscale_factor + + x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor]) + x = x.permute(0, 3, 1, 2).contiguous() + x = x.view([batch_size, short_channel_len, short_width]) + return x \ No newline at end of file diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py index f67ec2e6..fd3c97f8 100644 --- a/codes/scripts/audio/test_audio_gen.py +++ b/codes/scripts/audio/test_audio_gen.py @@ -3,6 +3,8 @@ import logging import random import argparse +import torchvision + import utils import utils.options as option import utils.util as util @@ -19,11 +21,20 @@ def forward_pass(model, denoiser, data, output_dir, opt, b): with torch.no_grad(): model.feed_data(data, 0) model.test() + pred_waveforms = model.eval_state[opt['eval']['output_state']][0] pred_waveforms = denoiser(pred_waveforms) ground_truth_waveforms = model.eval_state[opt['eval']['ground_truth']][0] ground_truth_waveforms = denoiser(ground_truth_waveforms) for i in range(pred_waveforms.shape[0]): + # Output predicted mels and waveforms. 
+        pred_mel = model.eval_state[opt['eval']['pred_mel']][i]
+        pred_mel = ((pred_mel - pred_mel.mean()) / max(abs(pred_mel.min()), pred_mel.max())).unsqueeze(1)
+        torchvision.utils.save_image(pred_mel, osp.join(output_dir, f'{b}_{i}_pred_mel.png'))
+        gt_mel = model.eval_state[opt['eval']['ground_truth_mel']][i]
+        gt_mel = ((gt_mel - gt_mel.mean()) / max(abs(gt_mel.min()), gt_mel.max())).unsqueeze(1)
+        torchvision.utils.save_image(gt_mel, osp.join(output_dir, f'{b}_{i}_gt_mel.png'))
+
         audio = pred_waveforms[i][0].cpu().numpy()
         wavfile.write(osp.join(output_dir, f'{b}_{i}.wav'), 22050, audio)
         audio = ground_truth_waveforms[i][0].cpu().numpy()
diff --git a/codes/trainer/optimizers/sgd.py b/codes/trainer/optimizers/sgd.py
index f82bf33c..a55f0d7a 100644
--- a/codes/trainer/optimizers/sgd.py
+++ b/codes/trainer/optimizers/sgd.py
@@ -54,7 +54,7 @@ class SGDNoBiasMomentum(Optimizer):
                 if weight_decay != 0:
                     d_p = d_p.add(p, alpha=weight_decay)
                 # **this is the only modification over standard torch.optim.SGD:
-                is_bn_or_bias = (hasattr(p, 'is_bn') and p.is_bn) or (hasattr(p, 'is_bias') and p.is_bias)
+                is_bn_or_bias = (hasattr(p, 'is_norm') and p.is_norm) or (hasattr(p, 'is_bias') and p.is_bias)
                 if not is_bn_or_bias and momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
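
The text_block_size argument threaded through GPT, Block and CausalSelfAttention above changes the attention pattern: the causal mask is OR-ed with a square of ones covering the text prefix, so text tokens attend to each other bidirectionally while the mel frames that follow stay strictly causal (a prefix-LM style mask). The sketch below reproduces that construction on a small 2-D mask; the sizes and variable names are illustrative only and are not part of the patch.

import torch
import torch.nn.functional as F

T, text_block_size = 6, 3                      # toy sequence length and text-prefix length

causal = torch.tril(torch.ones(T, T))          # standard lower-triangular causal mask
text_full = F.pad(torch.ones(text_block_size, text_block_size),
                  (0, T - text_block_size, 0, T - text_block_size))
combined = causal.logical_or(text_full)        # 1 = attention allowed

print(combined.long())
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1]])
# The first text_block_size rows see the whole text prefix; later rows remain causal.
# CausalSelfAttention.forward fills positions where this combined mask is 0 with -inf
# before the softmax, which is what the masked_fill(... == 0, float('-inf')) call does.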
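
PixelShuffle1D and PixelUnshuffle1D in the new pixelshuffle_1d.py trade sequence length against channel count: the unshuffle in mel_encoder halves the number of mel frames fed to the GPT while doubling channels, and the shuffle in gate_head/mel_head undoes it, replacing the earlier stride-2 causal convolutions and nearest-neighbour upsampling. A minimal round-trip check, assuming it is run with the repo's codes/ directory on the import path; the tensor sizes are arbitrary:

import torch
from models.gpt_voice.pixelshuffle_1d import PixelShuffle1D, PixelUnshuffle1D

x = torch.randn(2, 80, 748)        # (batch, mel channels, frames); frame count must be divisible by 2,
                                   # which is why GptTts.forward pads mel_targets to an even length
down = PixelUnshuffle1D(2)(x)      # -> (2, 160, 374): halves length, doubles channels
up = PixelShuffle1D(2)(down)       # -> (2, 80, 748): exact inverse of the unshuffle
assert torch.equal(up, x)
print(down.shape, up.shape)

Because both modules only reorder elements (views and permutes, no arithmetic), the round trip is lossless, unlike the strided convolution and nearest-neighbour upsample pair they replace.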