forked from mrq/DL-Art-School
Rework how DVAE tokens are ordered
It might make more sense to emit all top tokens first, then all bottom tokens, with the top tokens shifted into a different range of discretized values.
This commit is contained in: parent 4017236ba9, commit c0f61a2e15
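For orientation, here is a minimal sketch (not part of the commit) of the two orderings, assuming per-example top codes idt of shape (b, s), bottom codes idb of shape (b, 2*s), and an illustrative codebook_size of 512:

import torch

# Illustration only: compare the old interleaved ordering with the new
# "top first, then bottom" ordering introduced by this commit.
codebook_size = 512   # hypothetical value, chosen only for this sketch
b, s = 1, 4
idt = torch.randint(0, codebook_size, (b, s))      # top-level codes
idb = torch.randint(0, codebook_size, (b, 2 * s))  # bottom-level codes (2 per top code)

# Old ordering: top and bottom interleaved, roughly [t0, b0, b1, t1, b2, b3, ...].
old = torch.cat([idt.view(b, s, 1),
                 idb.reshape(b, 2, s).permute(0, 2, 1).contiguous()], dim=2).reshape(b, s * 3)

# New ordering: all top codes first, offset by codebook_size so they occupy a
# separate discretized value range, followed by all bottom codes unchanged.
new = torch.cat([idt + codebook_size, idb], dim=1)

assert old.shape == new.shape == (b, 3 * s)

Offsetting the top codes by codebook_size is presumably why MEL_DICTIONARY_SIZE in GptTts grows from 512+3 to 1024+3 in the diff below: the joint token vocabulary now spans two codebook_size-sized ranges plus the special tokens.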
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from models.gpt_voice.lucidrains_gpt import Transformer
+from models.gpt_voice.min_gpt import GPT, GPTConfig
 from models.tacotron2.taco_utils import get_mask_from_lengths
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
@@ -12,7 +13,7 @@ class GptTts(nn.Module):
     MAX_SYMBOLS_PER_PHRASE = 200
     NUMBER_SYMBOLS = len(symbols)
     NUMBER_TEXT_TOKENS = NUMBER_SYMBOLS + MAX_SYMBOLS_PER_PHRASE + 2
-    MEL_DICTIONARY_SIZE = 512+3
+    MEL_DICTIONARY_SIZE = 1024+3
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2
 
@@ -27,8 +28,8 @@ class GptTts(nn.Module):
         self.mel_embedding = nn.Embedding(self.MEL_DICTIONARY_SIZE, model_dim)
         self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
         self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
-        #self.gpt = GPT(GPTConfig(1+max_symbols_per_phrase+max_mel_frames, n_embd=model_dim, n_head=8), do_pos_emb=False)
-        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
+        self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
+        #self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
 
         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)

@@ -184,6 +184,7 @@ class VQVAE(nn.Module):
         self.unsqueeze_channels = in_channel == -1
         in_channel = abs(in_channel)
 
+        self.codebook_size = codebook_size
         self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4, conv_module=conv_module)
         self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2, conv_module=conv_module)
         self.quantize_conv_t = conv_module(channel, codebook_dim, 1)
@@ -238,12 +239,9 @@
 
     def encode_only_quantized(self, input):
         qt, qb, d, idt, idb = self.encode(input)
-        # Interleave top and bottom so top comes first and bottom comes second, such that the output looks like
-        # [t0,b0,b1,t1,b1,b2,t2,b3,b4....]
-        b, s = idt.shape
-        idt = idt.view(b, s, 1)
-        idb = idb.reshape(b, 2, s).permute(0,2,1).contiguous()
-        ids = torch.cat([idt, idb], dim=2).reshape(b, s*3)
+        # Append top and bottom into the same sequence, adding the codebook length onto the top to discriminate it.
+        idt += self.codebook_size
+        ids = torch.cat([idt, idb], dim=1)
         return ids
 
     def decode(self, quant_t, quant_b):
@@ -269,9 +267,9 @@ class VQVAE(nn.Module):
         assert s % 3 == 0 # If not, this tensor didn't come from encode_only_quantized.
         s = s // 3
 
-        input = input.reshape(b, s, 3).permute(0,2,1).contiguous()
-        t = input[:,0,:]
-        b = input[:,1:,:].reshape(b, 2*s)
+        # This doesn't work with batching. TODO: fixme.
+        t = input[:,:s] - self.codebook_size
+        b = input[:,s:]
         return self.decode_code(t, b)
 
 
@@ -295,5 +293,5 @@ if __name__ == '__main__':
     model = VQVAE(in_channel=80, conv_module=nn.Conv1d, conv_transpose_module=nn.ConvTranspose1d)
     #res=model(torch.randn(1,80,2048))
     e = model.encode_only_quantized(torch.randn(1, 80, 2048))
-    model.decode_code_joined(e)
-    print(res[0].shape)
+    k = model.decode_code_joined(e)
+    print(k.shape)
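For completeness, a hedged sketch of the inverse split that decode_code_joined now performs on the joined sequence; split_joined is a hypothetical standalone helper written for illustration, not a method from the repository:

import torch

def split_joined(ids: torch.Tensor, codebook_size: int):
    # Assumes ids was built as [top codes + codebook_size, bottom codes],
    # i.e. s offset top tokens followed by 2*s bottom tokens.
    b, total = ids.shape
    assert total % 3 == 0, "sequence did not come from encode_only_quantized"
    s = total // 3
    t = ids[:, :s] - codebook_size  # undo the offset applied at encode time
    bt = ids[:, s:]                 # bottom codes were stored unshifted
    return t, bt

In the commit itself the recovered top and bottom indices are then handed straight to decode_code(t, b), as the hunk above shows.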
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     torch.backends.cudnn.benchmark = True
     want_metrics = False
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/test_vqvae_audio_lj.yml')
     opt = option.parse(parser.parse_args().opt, is_train=False)
     opt = option.dict_to_nonedict(opt)
     utils.util.loaded_options = opt

@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()