From e723137273a32724bcf9ca9d53996808e1f83ceb Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 6 Aug 2021 22:08:51 -0600
Subject: [PATCH] Make gpttts more configurable

---
 codes/models/gpt_voice/gpt_tts.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/codes/models/gpt_voice/gpt_tts.py b/codes/models/gpt_voice/gpt_tts.py
index c76ee65e..6f38a3e6 100644
--- a/codes/models/gpt_voice/gpt_tts.py
+++ b/codes/models/gpt_voice/gpt_tts.py
@@ -7,6 +7,7 @@ from models.gpt_voice.min_gpt import GPT, GPTConfig
 from models.tacotron2.taco_utils import get_mask_from_lengths
 from models.tacotron2.text import symbols
 from trainer.networks import register_model
+from utils.util import opt_get


 class GptTts(nn.Module):
@@ -17,9 +18,8 @@ class GptTts(nn.Module):
     MEL_START_TOKEN = MEL_DICTIONARY_SIZE-3
     MEL_STOP_TOKEN = MEL_DICTIONARY_SIZE-2

-    def __init__(self):
+    def __init__(self, layers=8, model_dim=512, heads=8):
         super().__init__()
-        model_dim = 512
         max_mel_frames = 900 * 1 // 4  # 900 is the max number of MEL frames. The VQVAE outputs 1/8 of the input mel as tokens.

         self.model_dim = model_dim
@@ -29,7 +29,7 @@ class GptTts(nn.Module):
         self.text_pos_embedding = nn.Embedding(self.MAX_SYMBOLS_PER_PHRASE, model_dim)
         self.mel_pos_embedding = nn.Embedding(max_mel_frames, model_dim)
         #self.gpt = GPT(GPTConfig(1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, n_layer=8, n_embd=model_dim, n_head=8), do_pos_emb=False)
-        self.gpt = Transformer(dim=model_dim, depth=8, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=8)
+        self.gpt = Transformer(dim=model_dim, depth=layers, seq_len=1+self.MAX_SYMBOLS_PER_PHRASE+max_mel_frames, heads=heads)

         self.final_norm = nn.LayerNorm(model_dim)
         self.text_head = nn.Linear(model_dim, self.NUMBER_TEXT_TOKENS)
@@ -114,7 +114,7 @@ class GptTts(nn.Module):

 @register_model
 def register_gpt_tts(opt_net, opt):
-    return GptTts()
+    return GptTts(**opt_get(opt_net, ['kwargs'], {}))


 if __name__ == '__main__':
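
Usage note (not part of the patch): the sketch below shows how the new `kwargs` pass-through is meant to be consumed. Only the `kwargs` key is read by register_gpt_tts(); the surrounding option keys and the chosen values are assumptions for illustration, not defined by this change.

    from models.gpt_voice.gpt_tts import GptTts
    from utils.util import opt_get

    # `opt_net` mirrors the network block of a trainer config. Only `kwargs`
    # matters to this patch; the other key is a hypothetical placeholder.
    opt_net = {
        'which_model': 'gpt_tts',   # assumed registry key, illustration only
        'kwargs': {                 # forwarded verbatim to GptTts.__init__
            'layers': 12,
            'model_dim': 768,
            'heads': 12,
        },
    }

    # Same call the patched register_gpt_tts() makes. The {} default means a
    # config without a `kwargs` block still builds the previous model shape
    # (layers=8, model_dim=512, heads=8), per the new __init__ defaults.
    model = GptTts(**opt_get(opt_net, ['kwargs'], {}))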