Make gpttts more configurable
parent a7496b661c
commit e723137273
@@ -7,6 +7,7 @@ from models.gpt_voice.min_gpt import GPT, GPTConfig
 from models.tacotron2.taco_utils import get_mask_from_lengths
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
+from utils.util import opt_get


 class GptTts(nn.Module):
@@ -17,9 +18,8 @@ class GptTts(nn.Module):
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2

-    def __init__(self):
+    def __init__(self, layers=8, model_dim=512, heads=8):
         super().__init__()
-        model_dim = 512
         max_mel_frames = 900 * 1 // 4  # 900 is the max number of MEL frames. The VQVAE outputs 1/8 of the input mel as tokens.

         self.model_dim = model_dim
@@ -29,7 +29,7 @@ class GptTts(nn.Module):
         self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
         self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
         #self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
-        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
+        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=heads)

         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
@@ -114,7 +114,7 @@ class GptTts(nn.Module):

 @register_model
 def register_gpt_tts(opt_net, opt):
-    return GptTts()
+    return GptTts(**opt_get(opt_net, ['kwargs'], {}))


 if __name__ == '__main__':
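
With this change, the GPT depth, embedding width, and head count are no longer hard-coded and can be driven from the training options. Below is a minimal sketch of how the forwarded kwargs might be supplied, assuming the trainer hands the network options dict to register_gpt_tts as opt_net; the key names other than 'kwargs' and the specific values are illustrative, not taken from this commit.

    # Hypothetical network options block; only the 'kwargs' entry is read here.
    opt_net = {
        'which_model': 'gpt_tts',   # illustrative key, depends on the trainer's config schema
        'kwargs': {                 # forwarded verbatim to GptTts.__init__
            'layers': 12,
            'model_dim': 768,
            'heads': 12,
        },
    }

    # Roughly equivalent to GptTts(layers=12, model_dim=768, heads=12).
    # With no 'kwargs' entry, opt_get falls back to {} and the defaults
    # (layers=8, model_dim=512, heads=8) from __init__ apply.
    model = register_gpt_tts(opt_net, opt=None)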