forked from mrq/DL-Art-School
Fix (?) use_gpt_tts for unified_voice
parent 3c4301f085
commit 10fd1110be
@@ -119,8 +119,8 @@ class UnifiedGptVoice(nn.Module):
         self.mel_pos_embedding = nn.Embedding(self.max_mel_tokens + 1, model_dim)
         seq_length = 2+max_text_tokens+self.max_mel_tokens+self.max_conditioning_inputs
         self.gpt_config = GPT2Config(vocab_size=self.number_mel_codes,
-                                     n_positions=seq_length,
-                                     n_ctx=seq_length,
+                                     n_positions=seq_length-100, # -100 is a hack for backwards compatibility. TODO: remove at some point.
+                                     n_ctx=seq_length-100,
                                      n_embd=model_dim,
                                      n_layer=layers,
                                      n_head=heads,
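
The -100 offset is load-bearing because GPT-2 sizes its learned position-embedding table directly from n_positions, so a checkpoint saved against the smaller table only loads if the config reproduces it exactly. A minimal sketch of that constraint (all sizes below are made up for illustration, not this repo's values):

    # Sketch only: model_dim/seq_length are illustrative, not the repo's real values.
    from transformers import GPT2Config, GPT2Model

    model_dim, seq_length = 512, 600
    config = GPT2Config(vocab_size=8194, n_positions=seq_length - 100, n_ctx=seq_length - 100,
                        n_embd=model_dim, n_layer=8, n_head=8)
    gpt = GPT2Model(config)
    # The position table has shape [n_positions, n_embd]; load_state_dict() on a
    # checkpoint trained with a different n_positions fails with a size mismatch here.
    print(gpt.wpe.weight.shape)  # torch.Size([500, 512])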
@@ -285,10 +285,10 @@ class UnifiedGptVoice(nn.Module):
 
     def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
         if not hasattr(self, 'inference_model'):
-            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_paired_embedding, self.final_norm, self.mel_head)
+            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
 
         text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
-        text_emb = self.text_embedding(text_inputs) + self.text_pos_paired_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
+        text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(torch.arange(text_inputs.shape[1], device=text_inputs.device))
 
         if self.shuffle_conditioning:
             # Randomly permute the conditioning spectrogram, to destroy any structure present.
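
The inference path now references the position embeddings this class actually registers in __init__ (visible in the first hunk); the *_pos_paired_embedding names appear to be leftovers from an earlier revision of the module. A tiny, hypothetical repro of the failure the rename avoids:

    # Hypothetical stand-in module that, like UnifiedGptVoice here, registers only
    # the unpaired embedding name.
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            self.mel_pos_embedding = nn.Embedding(10, 4)

    demo = Demo()
    print(hasattr(demo, 'mel_pos_embedding'))         # True
    print(hasattr(demo, 'mel_pos_paired_embedding'))  # False: the old code path
                                                      # raised AttributeError here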
@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from einops import rearrange
 from torch import einsum
 
-from models.gpt_voice.dvae_arch_playground.discretization_loss import DiscretizationLoss
+from models.gpt_voice.lucidrains_dvae import DiscretizationLoss
 from models.vqvae.vector_quantizer import VectorQuantize
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
@@ -86,14 +86,15 @@ if __name__ == '__main__':
     parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
     parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth')
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
-    parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_voice.yml')
+    parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts.yml')
     parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
-    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_voice\\models\\54000_gpt.pth')
+    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts_libri_all_and_hifi_no_unsupervised\\models\\4000_gpt.pth')
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-cond_path', type=str, help='Path to condioning sample.', default='')
     parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
     parser.add_argument('-num_samples', type=int, help='How many outputs to produce.', default=1)
     args = parser.parse_args()
+    # libritts_text = 'fall passed so quickly, there was so much going on around him, the tree quite forgot to look to himself.'
 
     print("Loading GPT TTS..")
     with open(args.opt_gpt_tts, mode='r') as f:
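
Only argparse defaults change in this hunk, so behaviour under explicit flags is untouched; a self-contained sketch of that interaction (two of the script's options reproduced, the preset name in the override is hypothetical):

    # Minimal argparse sketch mirroring two of the script's options.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_gpt_tts', type=str, default='X:\\dlas\\experiments\\train_gpt_unified_finetune_tts.yml')
    parser.add_argument('-cond_preset', type=str, default='libri_test')

    args = parser.parse_args([])   # no flags: the new defaults apply
    print(args.opt_gpt_tts)        # X:\dlas\experiments\train_gpt_unified_finetune_tts.yml
    args = parser.parse_args(['-cond_preset', 'my_voice'])  # 'my_voice' is hypothetical
    print(args.cond_preset)        # explicit flags still override the defaults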
@@ -103,11 +104,10 @@ if __name__ == '__main__':
 
     print("Loading data..")
     tokenizer = CharacterTokenizer()
-    text = torch.IntTensor(tokenizer.encode(args.text.strip().lower()).ids).unsqueeze(0).cuda()
+    text = torch.IntTensor(tokenizer.encode(args.text)).unsqueeze(0).cuda()
+    text = F.pad(text, (0,1)) # This may not be necessary.
     paired_text_length = gpt_opt['datasets']['train']['max_paired_text_length']
-    padding_needed = paired_text_length - text.shape[1]
-    assert padding_needed > 0
-    text = F.pad(text, (0,padding_needed))
+    assert paired_text_length >= text.shape[1]
 
     cond_path = args.cond_path if args.cond_preset is None else preselected_cond_voices[args.cond_preset]
     conds, cond_wav = load_conditioning(cond_path)
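
The encode fix tracks the tokenizer in use: a tokenizers-style BPE encoder returns an Encoding object whose ids live in .ids, while CharacterTokenizer is assumed to hand back a plain list of ids, so the .ids access (and the .strip().lower() normalization) goes away; the fixed-length padding likewise shrinks to a bounds assert. A stand-in sketch of the interface mismatch:

    # CharacterTokenizer stand-in; the real class lives in this repo, and the only
    # assumption here is that its encode() returns a plain list of ids.
    import torch

    class CharacterTokenizer:
        def encode(self, text):
            return [ord(c) for c in text]  # a list has no .ids attribute

    tokenizer = CharacterTokenizer()
    text = torch.IntTensor(tokenizer.encode("I am a language model.")).unsqueeze(0)
    print(text.shape)  # torch.Size([1, 22])
    # The old call, tokenizer.encode(...).ids, assumed a BPE Encoding object and
    # raised AttributeError once CharacterTokenizer was in use.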