From c0f61a2e151e026a0b82463f0d1f099b9f7d6739 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 5 Aug 2021 07:07:17 -0600
Subject: [PATCH] Rework how DVAE tokens are ordered

It might make more sense to have the top tokens first, then the bottom
tokens, with the top tokens shifted into a different range of discretized
values.
---
 codes/models/gpt_voice/gpt_tts.py     |  7 ++++---
 codes/models/vqvae/vqvae.py           | 20 +++++++++-----------
 codes/scripts/audio/test_audio_gen.py |  2 +-
 codes/train.py                        |  2 +-
 4 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index 2c678f07..74fefed2 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models.gpt_voice.lucidrains_gpt import Transformer
+from models.gpt_voice.min_gpt import GPT, GPTConfig
 from models.tacotron2.taco_utils import get_mask_from_lengths
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
@@ -12,7 +13,7 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 512+3
+    MEL_DICTIONARY_SIZE = 1024+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
 
@@ -27,8 +28,8 @@ class GptTts(nn.Module):
         self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
         self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
         self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
-        #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
-        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
+        self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
+        #self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
 
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
diff --git a/codes/models/vqvae/vqvae.py b/codes/models/vqvae/vqvae.py
index 5b7cc09c..5b3649d3 100644
--- a/codes/models/vqvae/vqvae.py
+++ b/codes/models/vqvae/vqvae.py
@@ -184,6 +184,7 @@ class VQVAE(nn.Module):
         self.unsqueeze_channels = in_channel == -1
         in_channel = abs(in_channel)
 
+        self.codebook_size = codebook_size
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
         self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
         self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
@@ -238,12 +239,9 @@
 
     def encode_only_quantized(self, input):
         qt, qb, d, idt, idb = self.encode(input)
-        # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
-        # [t0,b0,b1,t1,b1,b2,t2,b3,b4....]
-        b, s = idt.shape
-        idt = idt.view(b, s, 1)
-        idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
-        ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
+        # Append top and bottom into the same sequence, adding the codebook size onto the top tokens to distinguish them.
+        idt += self.codebook_size
+        ids = torch.cat([idt, idb], dim=1)
         return ids
 
     def decode(self, quant_t, quant_b):
@@ -269,9 +267,9 @@
 
     def decode_code_joined(self, input):
         b, s = input.shape
         assert s % 3 == 0  # If not, this tensor didn't come from encode_only_quantized.
         s = s // 3
-        input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
-        t = input[:,0,:]
-        b = input[:,1:,:].reshape(b, 2*s)
+        # This doesn't work with batching. TODO: fixme.
+        t = input[:,:s] - self.codebook_size
+        b = input[:,s:]
         return self.decode_code(t, b)
@@ -295,5 +293,5 @@ if __name__ == '__main__':
     model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
     #res=model(torch.randn(1,80,2048))
     e = model.encode_only_quantized(torch.randn(1, 80, 2048))
-    model.decode_code_joined(e)
-    print(res[0].shape)
\ No newline at end of file
+    k = model.decode_code_joined(e)
+    print(k.shape)
\ No newline at end of file
diff --git a/codes/scripts/audio/test_audio_gen.py b/codes/scripts/audio/test_audio_gen.py
index fd3c97f8..ce441c11 100644
--- a/codes/scripts/audio/test_audio_gen.py
+++ b/codes/scripts/audio/test_audio_gen.py
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
diff --git a/codes/train.py b/codes/train.py
index 03513ad2..da87349b 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
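
Note on the new token layout: where encode_only_quantized previously interleaved
the two code streams as [t0,b0,b1,t1,b2,b3,...], it now emits all top codes
first, shifted up by the codebook size so they occupy a token range disjoint
from the bottom codes, followed by the bottom codes unchanged. Below is a
minimal standalone round-trip sketch of that scheme, not the repo's code:
shapes and values are hypothetical, and it assumes a 512-entry codebook per
level (suggested by the MEL_DICTIONARY_SIZE bump to 1024+3) and a bottom
sequence twice the top's length (enc_b downsamples 4x, enc_t a further 2x),
which is what the s % 3 == 0 assert in decode_code_joined relies on.

    import torch

    codebook_size = 512   # hypothetical; 2*512 = 1024 matches the new MEL_DICTIONARY_SIZE less its 3 special tokens
    b, s = 1, 8           # hypothetical batch size and top-code sequence length
    idt = torch.randint(0, codebook_size, (b, s))        # top codes
    idb = torch.randint(0, codebook_size, (b, 2 * s))    # bottom codes, 2x longer in this 1D setup

    # Encode: shift top codes into [codebook_size, 2*codebook_size) and append the bottom codes.
    ids = torch.cat([idt + codebook_size, idb], dim=1)   # shape (b, 3*s)

    # Decode: recover the split point from the fixed 1:2 length ratio and undo the shift.
    assert ids.shape[1] % 3 == 0
    split = ids.shape[1] // 3
    top = ids[:, :split] - codebook_size
    bottom = ids[:, split:]
    assert torch.equal(top, idt) and torch.equal(bottom, idb)

One consequence of this layout: any consumer embedding the joint sequence (such
as the mel embedding in gpt_tts.py) must size its vocabulary at 2*codebook_size
plus special tokens, since top and bottom codes no longer share ids; that is
why this patch doubles MEL_DICTIONARY_SIZE from 512+3 to 1024+3.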