Rework how DVAE tokens are ordered

It might make more sense to have top tokens, then bottom tokens
with top tokens having different discretized values.
This commit is contained in:
James Betker 2021-08-05 07:07:17 -06:00
parent 4017236ba9
commit c0f61a2e15
4 changed files with 15 additions and 16 deletions

View File

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from models.gpt_voice.lucidrains_gpt import Transformer from models.gpt_voice.lucidrains_gpt import Transformer
from models.gpt_voice.min_gpt import GPT, GPTConfig
from models.tacotron2.taco_utils import get_mask_from_lengths from models.tacotron2.taco_utils import get_mask_from_lengths
from models.tacotron2.text import symbols from models.tacotron2.text import symbols
from trainer.networks import register_model from trainer.networks import register_model
@ -12,7 +13,7 @@ class GptTts(nn.Module):
MAX_SYMBOLS_PER_PHRASE = 200 MAX_SYMBOLS_PER_PHRASE = 200
NUMBER_SYMBOLS = len(symbols) NUMBER_SYMBOLS = len(symbols)
NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2 NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
MEL_DICTIONARY_SIZE = 512+3 MEL_DICTIONARY_SIZE = 1024+3
MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3 MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2 MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
@ -27,8 +28,8 @@ class GptTts(nn.Module):
self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim) self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim) self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim) self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
#self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False) self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8) #self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
self.final_norm = nn.LayerNorm(model_dim) self.final_norm = nn.LayerNorm(model_dim)
self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS) self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)

View File

@ -184,6 +184,7 @@ class VQVAE(nn.Module):
self.unsqueeze_channels = in_channel == -1 self.unsqueeze_channels = in_channel == -1
in_channel = abs(in_channel) in_channel = abs(in_channel)
self.codebook_size = codebook_size
self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module) self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module) self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
self.quantize_conv_t = conv_module(channel, codebook_dim, 1) self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
@ -238,12 +239,9 @@ class VQVAE(nn.Module):
def encode_only_quantized(self, input): def encode_only_quantized(self, input):
qt, qb, d, idt, idb = self.encode(input) qt, qb, d, idt, idb = self.encode(input)
# Interleave top and bottom so top comes first and bottom comes second, such that the output looks like # Append top and bottom into the same sequence, adding the codebook length onto the top to discriminate it.
# [t0,b0,b1,t1,b1,b2,t2,b3,b4....] idt += self.codebook_size
b, s = idt.shape ids = torch.cat([idt, idb], dim=1)
idt = idt.view(b, s, 1)
idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
return ids return ids
def decode(self, quant_t, quant_b): def decode(self, quant_t, quant_b):
@ -269,9 +267,9 @@ class VQVAE(nn.Module):
assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized. assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized.
s = s // 3 s = s // 3
input = input.reshape(b, s, 3).permute(0,2,1).contiguous() # This doesn't work with batching. TODO: fixme.
t = input[:,0,:] t = input[:,:s] - self.codebook_size
b = input[:,1:,:].reshape(b, 2*s) b = input[:,s:]
return self.decode_code(t, b) return self.decode_code(t, b)
@ -295,5 +293,5 @@ if __name__ == '__main__':
model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d) model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
#res=model(torch.randn(1,80,2048)) #res=model(torch.randn(1,80,2048))
e = model.encode_only_quantized(torch.randn(1, 80, 2048)) e = model.encode_only_quantized(torch.randn(1, 80, 2048))
model.decode_code_joined(e) k = model.decode_code_joined(e)
print(res[0].shape) print(k.shape)

View File

@ -51,7 +51,7 @@ if __name__ == "__main__":
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
want_metrics = False want_metrics = False
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt) opt = option.dict_to_nonedict(opt)
utils.util.loaded_options = opt utils.util.loaded_options = opt

View File

@ -300,7 +300,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()