Rework how DVAE tokens are ordered
It might make more sense to emit the top tokens first, then the bottom tokens, with the top tokens offset into a different range of discretized values so the two levels can be told apart.
parent 4017236ba9
commit c0f61a2e15
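To make the change concrete, here is a minimal sketch of the old interleaved ordering versus the new concatenated ordering. The tensor names mirror the VQVAE diff below; the shapes (s top codes, 2*s bottom codes per clip) and the 512-entry codebook are assumptions inferred from the hunks, not code taken from this commit.

import torch

# Illustrative stand-ins for the DVAE code sequences: s top codes and 2*s bottom codes.
codebook_size = 512
b, s = 1, 4
idt = torch.randint(0, codebook_size, (b, s))        # top-level codes
idb = torch.randint(0, codebook_size, (b, 2 * s))    # bottom-level codes

# Old ordering: interleave so each top code is followed by two bottom codes.
interleaved = torch.cat([idt.view(b, s, 1),
                         idb.reshape(b, 2, s).permute(0, 2, 1)], dim=2).reshape(b, s * 3)

# New ordering: all top codes first, shifted by the codebook size so they occupy a
# distinct value range, then all bottom codes.
joined = torch.cat([idt + codebook_size, idb], dim=1)

Both layouts produce a sequence of length 3*s; the difference is only in where the top codes sit and whether their values overlap with the bottom codes.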
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models.gpt_voice.lucidrains_gpt import Transformer
+from models.gpt_voice.min_gpt import GPT, GPTConfig
 from models.tacotron2.taco_utils import get_mask_from_lengths
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
@@ -12,7 +13,7 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 512+3
+    MEL_DICTIONARY_SIZE = 1024+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
 
@@ -27,8 +28,8 @@ class GptTts(nn.Module):
         self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
         self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
         self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
-        #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
-        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
+        self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
+        #self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
 
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
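The doubled MEL_DICTIONARY_SIZE above follows from the new ordering: top codes are shifted into the range starting at the codebook size, so the GPT's mel vocabulary must cover both value ranges plus its special tokens. A rough sanity check, assuming a 512-entry VQVAE codebook (consistent with the 512+3 to 1024+3 change; what the third reserved slot is used for is not shown in this diff):

VQVAE_CODEBOOK_SIZE = 512                            # assumed
MEL_DICTIONARY_SIZE = 2 * VQVAE_CODEBOOK_SIZE + 3    # 1027, i.e. the 1024+3 in the diff
MEL_START_TOKEN = MEL_DICTIONARY_SIZE - 3            # 1024
MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE - 2             # 1025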
@@ -184,6 +184,7 @@ class VQVAE(nn.Module):
         self.unsqueeze_channels = in_channel == -1
         in_channel = abs(in_channel)
 
+        self.codebook_size = codebook_size
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
         self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
         self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
@@ -238,12 +239,9 @@ class VQVAE(nn.Module):
 
     def encode_only_quantized(self, input):
         qt, qb, d, idt, idb = self.encode(input)
-        # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
-        # [t0,b0,b1,t1,b1,b2,t2,b3,b4....]
-        b, s = idt.shape
-        idt = idt.view(b, s, 1)
-        idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
-        ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
+        # Append top and bottom into the same sequence, adding the codebook length onto the top to discriminate it.
+        idt += self.codebook_size
+        ids = torch.cat([idt, idb], dim=1)
         return ids
 
     def decode(self, quant_t, quant_b):
@@ -269,9 +267,9 @@ class VQVAE(nn.Module):
         assert s % 3 == 0   # If not, this tensor didn't come from encode_only_quantized.
         s = s // 3
 
-        input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
-        t = input[:,0,:]
-        b = input[:,1:,:].reshape(b, 2*s)
+        # This doesn't work with batching. TODO: fixme.
+        t = input[:,:s] - self.codebook_size
+        b = input[:,s:]
         return self.decode_code(t, b)
 
 
@@ -295,5 +293,5 @@ if __name__ == '__main__':
     model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
     #res=model(torch.randn(1,80,2048))
     e = model.encode_only_quantized(torch.randn(1, 80, 2048))
-    model.decode_code_joined(e)
-    print(res[0].shape)
+    k = model.decode_code_joined(e)
+    print(k.shape)
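A self-contained sketch of the round trip implied by the new encode_only_quantized / decode_code_joined pair. The tensors are illustrative, not the model's actual encoder output, and the sketch assumes the bottom level emits exactly twice as many codes as the top level, which is what the s % 3 == 0 check relies on.

import torch

codebook_size, b, s = 512, 1, 4
idt = torch.randint(0, codebook_size, (b, s))        # stand-in for top codes
idb = torch.randint(0, codebook_size, (b, 2 * s))    # stand-in for bottom codes

# Encode side (mirrors encode_only_quantized): offset the top codes, then append the bottom codes.
joined = torch.cat([idt + codebook_size, idb], dim=1)

# Decode side (mirrors decode_code_joined): recover both code streams from the joined sequence.
assert joined.shape[1] % 3 == 0
n = joined.shape[1] // 3
top = joined[:, :n] - codebook_size
bottom = joined[:, n:]
assert torch.equal(top, idt) and torch.equal(bottom, idb)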
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()