forked from mrq/DL-Art-School
Add inference for unified gpt
This commit is contained in:
parent ead2a74bf0
commit 8e26400ce2
@@ -40,7 +40,7 @@ class ConditioningEncoder(nn.Module):
 
 class UnifiedGptVoice(nn.Module):
     """
-    Derived from GptTtsHf, but offers multiple modes of operation:
+    Derived from GptTtsHf, but offers multiple modes of autoregressive operation:
     - Text only
     - Voice only
     - Text conditioned on voice
@@ -192,6 +192,28 @@ class UnifiedGptVoice(nn.Module):
         loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
         return loss_mel.mean()
 
+    def inference_speech(self, speech_conditioning_input, text_inputs, **hf_generate_kwargs):
+        if not hasattr(self, 'inference_model'):
+            self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, None, self.final_norm, self.mel_head)
+
+        text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
+        text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
+        text_emb = self.text_embedding(text_inputs)
+
+        # Randomly permute the conditioning spectrogram, to destroy any structure present.
+        speech_conditioning_input = self.randomly_permute_conditioning_input(speech_conditioning_input)
+        cond = self.conditioning_encoder(speech_conditioning_input).unsqueeze(1)
+
+        emb = torch.cat([cond, text_emb], dim=1)
+        self.inference_model.store_mel_emb(emb)
+
+        fake_inputs = torch.full((emb.shape[0], emb.shape[1] + 1,), fill_value=1, dtype=torch.long, device=text_inputs.device)
+        fake_inputs[:, -1] = self.START_MEL_TOKEN
+
+        gen = self.inference_model.generate(fake_inputs, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
+                                            max_length=emb.shape[1] + self.max_mel_tokens, **hf_generate_kwargs)
+        return gen[:, fake_inputs.shape[1]:]
+
+
 @register_model
 def register_unified_gpt_voice(opt_net, opt):
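
What makes this hunk work: HF's generate() only consumes token ids, so the method hands it a placeholder sequence (fake_inputs, all 1s) whose length matches the stored conditioning-plus-text embeddings, with the real START_MEL_TOKEN in the final slot. GPT2InferenceModel (defined elsewhere in the repo) can then substitute the stored emb for the placeholder embeddings before running the transformer. A minimal standalone sketch of that substitution, with every name below hypothetical rather than taken from the repo:

import torch
import torch.nn as nn

class PrefixSwapEmbedding(nn.Module):
    """Hypothetical helper: a token-embedding layer that substitutes a
    precomputed embedding sequence for the leading placeholder positions."""
    def __init__(self, wte: nn.Embedding, prefix: torch.Tensor):
        super().__init__()
        self.wte = wte        # the model's real token-embedding table
        self.prefix = prefix  # (b, p, dim), e.g. torch.cat([cond, text_emb], dim=1)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        emb = self.wte(ids)
        p = min(self.prefix.shape[1], ids.shape[1])
        # Placeholder ids in positions [0, p) never reach the transformer;
        # later positions (here, START_MEL_TOKEN onward) embed normally.
        return torch.cat([self.prefix[:, :p], emb[:, p:]], dim=1)

Because HF's generation loop uses a KV cache, only the first forward pass sees the full placeholder prefix; every later step passes just the newest sampled token, so the swap is effectively a one-time prompt injection. The final return gen[:, fake_inputs.shape[1]:] then strips the prompt, leaving only the newly generated mel codes.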
@@ -5,6 +5,7 @@ import torch
 import torch.nn.functional as F
 import torchaudio
 import yaml
+from tokenizers import Tokenizer
 
 from data.audio.unsupervised_audio_dataset import load_audio
 from data.util import is_audio_file, find_files_of_type
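
The new import is Hugging Face's tokenizers library, which replaces the character-level text_to_sequence frontend further down in this script. A quick sketch of what the loaded tokenizer produces; the vocab path is the -tokenizer_vocab_file default added in this commit, and the exact ids depend entirely on the trained vocab:

from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file('../experiments/custom_lowercase_gptvoice_tokenizer_r2.json')
enc = tokenizer.encode("i am a language model that has learned to speak.")
print(enc.ids)     # list[int] token ids; the script wraps these in an IntTensor
print(enc.tokens)  # the corresponding subword strings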
@@ -76,6 +77,7 @@ if __name__ == '__main__':
         'simmons': 'Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav',
         'news_girl': 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav',
         'dan_carlin': 'Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav',
+        'libri_test': 'Z:\\bigasr_dataset\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'
     }
 
     parser = argparse.ArgumentParser()
@@ -83,13 +85,14 @@ if __name__ == '__main__':
     parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
     parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth')
     parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
-    parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
+    parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_unified_voice.yml')
     parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
-    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts_no_pos\\models\\50000_gpt.pth')
+    parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_unified_voice\\models\\15000_gpt.pth')
     parser.add_argument('-text', type=str, help='Text to speak.', default="I am a language model that has learned to speak.")
     parser.add_argument('-cond_path', type=str, help='Path to conditioning sample.', default='')
-    parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='simmons')
+    parser.add_argument('-cond_preset', type=str, help='Use a preset conditioning voice (defined above). Overrides cond_path.', default='libri_test')
     parser.add_argument('-num_samples', type=int, help='How many outputs to produce.', default=1)
+    parser.add_argument('-tokenizer_vocab_file', type=str, help='Tokenizer vocabulary file used to train.', default='../experiments/custom_lowercase_gptvoice_tokenizer_r2.json')
     args = parser.parse_args()
 
     print("Loading GPT TTS..")
@@ -99,13 +102,14 @@ if __name__ == '__main__':
     gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path)
 
     print("Loading data..")
-    text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
+    tokenizer = Tokenizer.from_file(args.tokenizer_vocab_file)
+    text = torch.IntTensor(tokenizer.encode(args.text.strip().lower()).ids).unsqueeze(0).cuda()
     cond_path = args.cond_path if args.cond_preset is None else preselected_cond_voices[args.cond_preset]
     conds, cond_wav = load_conditioning(cond_path)
 
     print("Performing GPT inference..")
-    codes = gpt.inference(text, conds, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=20, top_p=.95,
-                          num_return_sequences=args.num_samples, length_penalty=.1, early_stopping=True)
+    codes = gpt.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=20, top_p=.95,
+                                 num_return_sequences=args.num_samples, length_penalty=1, early_stopping=True)
 
     # Delete the GPT TTS model to free up GPU memory
     stop_token = gpt.STOP_MEL_TOKEN
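
The keyword arguments passed to inference_speech above are forwarded straight into Hugging Face's GenerationMixin.generate(). One detail worth noting: with num_beams=1 and do_sample=True this is pure top-k/nucleus sampling, so length_penalty and early_stopping (beam-search knobs) have no effect, which makes the length_penalty change from .1 to 1 cosmetic. A standalone illustration of the same kwargs with vanilla GPT-2 (standard HF model ids, nothing from this repo):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
ids = tok('Hello there', return_tensors='pt').input_ids

with torch.no_grad():
    out = model.generate(ids, do_sample=True, top_k=20, top_p=0.95,
                         num_beams=1, repetition_penalty=1.0,
                         num_return_sequences=1, max_length=32,
                         pad_token_id=tok.eos_token_id)
print(tok.decode(out[0]))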