Rework how DVAE tokens are ordered

It might make more sense to emit all of the top tokens first, then all of the bottom tokens,
with the top tokens offset into their own range of discretized values so the two levels stay distinguishable.
James Betker 2021-08-05 07:07:17 -06:00
parent 4017236ba9
commit c0f61a2e15
4 changed files with 15 additions and 16 deletions
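
For illustration only (an editorial sketch, not part of the commit): the two orderings look like this, assuming the 512-entry codebooks used by the VQVAE below, where each top code covers two bottom codes.

import torch

codebook_size = 512
b, s = 1, 4                                      # batch size, number of top-level codes
idt = torch.randint(codebook_size, (b, s))       # top codes (coarse level)
idb = torch.randint(codebook_size, (b, 2 * s))   # bottom codes (2x the temporal resolution)

# Old ordering: interleave each top code with its two bottom codes,
# i.e. [t0, b0, b1, t1, b2, b3, ...].
interleaved = torch.cat([idt.view(b, s, 1), idb.view(b, s, 2)], dim=2).view(b, 3 * s)

# New ordering: all top codes first, then all bottom codes, with the top ids
# shifted by codebook_size so the two levels occupy disjoint id ranges:
# [t0+512, t1+512, ..., b0, b1, ...].
appended = torch.cat([idt + codebook_size, idb], dim=1)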

View File

@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from models.gpt_voice.lucidrains_gpt import Transformer
+from models.gpt_voice.min_gpt import GPT, GPTConfig
 from models.tacotron2.taco_utils import get_mask_from_lengths
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
@@ -12,7 +13,7 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 512+3
+    MEL_DICTIONARY_SIZE = 1024+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
@@ -27,8 +28,8 @@ class GptTts(nn.Module):
         self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
         self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
         self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
-        #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
-        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
+        self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
+        #self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
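
A note on the dictionary-size change above (my reading of the diff, not stated in the commit): with the top ids shifted by the codebook size, top and bottom codes now share one mel vocabulary, so it must cover both 512-entry ranges plus the special tokens.

codebook_size = 512
# bottom codes occupy ids [0, 512); shifted top codes occupy ids [512, 1024)
MEL_DICTIONARY_SIZE = 2 * codebook_size + 3  # 1024+3: both ranges plus three reserved ids
MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2     # one more reserved id (SIZE-1) remains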

View File

@@ -184,6 +184,7 @@ class VQVAE(nn.Module):
         self.unsqueeze_channels = in_channel == -1
         in_channel = abs(in_channel)
+        self.codebook_size = codebook_size
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
         self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
         self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
@@ -238,12 +239,9 @@ class VQVAE(nn.Module):
     def encode_only_quantized(self, input):
         qt, qb, d, idt, idb = self.encode(input)
-        # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
-        # [t0, b0, b1, t1, b2, b3, t2, b4, b5, ...]
-        b, s = idt.shape
-        idt = idt.view(b, s, 1)
-        idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
-        ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
+        # Append top and bottom into the same sequence, adding the codebook size to the top ids to keep them distinct.
+        idt += self.codebook_size
+        ids = torch.cat([idt, idb], dim=1)
         return ids

     def decode(self, quant_t, quant_b):
@@ -269,9 +267,9 @@ class VQVAE(nn.Module):
         assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized.
         s = s // 3
-        input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
-        t = input[:,0,:]
-        b = input[:,1:,:].reshape(b, 2*s)
+        # This doesn't work with batching. TODO: fixme.
+        t = input[:,:s] - self.codebook_size
+        b = input[:,s:]
         return self.decode_code(t, b)
@@ -295,5 +293,5 @@ if __name__ == '__main__':
     model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
     #res=model(torch.randn(1,80,2048))
     e = model.encode_only_quantized(torch.randn(1, 80, 2048))
-    model.decode_code_joined(e)
-    print(res[0].shape)
+    k = model.decode_code_joined(e)
+    print(k.shape)
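
A hedged round-trip sketch of the new joined encode/decode path (the `models.vqvae.vqvae` import path is an assumption; exact code-sequence lengths depend on the encoder strides):

import torch
import torch.nn as nn
from models.vqvae.vqvae import VQVAE  # assumed module path

model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
mel = torch.randn(1, 80, 2048)

ids = model.encode_only_quantized(mel)            # (1, 3*s): s top codes, then 2*s bottom codes
s = ids.shape[1] // 3
assert (ids[:, :s] >= model.codebook_size).all()  # top codes sit in the offset id range
recon = model.decode_code_joined(ids)             # batch size 1 only, per the TODO above
print(recon.shape)                                # roughly (1, 80, 2048)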

View File

@@ -51,7 +51,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt

View File

@@ -300,7 +300,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()