This commit is contained in:
James Betker 2021-12-18 16:45:38 -07:00
parent 9b9f7ea61b
commit 937045cb63
3 changed files with 16 additions and 19 deletions

View File

@ -511,7 +511,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.use_raw_y_as_embedding = use_raw_y_as_embedding
assert (self.num_classes is not None) != use_raw_y_as_embedding # These are mutually-exclusive.
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
self.input_blocks = nn.ModuleList(
[

View File

@ -27,9 +27,7 @@ class GptTtsHf(nn.Module):
super().__init__()
self.max_mel_tokens = max_mel_tokens
self.max_symbols_per_phrase = max_symbols_per_phrase
self.model_dim = model_dim
self.max_mel_tokens = max_mel_tokens
self.max_conditioning_inputs = max_conditioning_inputs
self.mel_length_compression = mel_length_compression
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
@ -112,8 +110,9 @@ class GptTtsHf(nn.Module):
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8):
if not hasattr(self, 'inference_model'):
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
text_emb = self.text_embedding(text_inputs)
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
@ -124,7 +123,7 @@ class GptTtsHf(nn.Module):
while len(conds) < self.max_conditioning_inputs:
conds.append(conds[-1])
conds = torch.stack(conds, dim=1)
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
conds = conds + self.conditioning_embedding
emb = torch.cat([text_emb, conds], dim=1)
self.inference_model.store_mel_emb(emb)
@ -133,8 +132,8 @@ class GptTtsHf(nn.Module):
fake_inputs[:,-1] = self.START_MEL_TOKEN
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
return gen[:, self.max_mel_frames:]
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True, repetition_penalty=.2)
return gen[:, fake_inputs.shape[1]:]
@register_model

View File

@ -1,8 +1,8 @@
import argparse
import os
import random
import torch
import torch.nn.functional as F
import torchaudio
import yaml
@ -10,12 +10,10 @@ from data.audio.unsupervised_audio_dataset import load_audio
from data.util import is_audio_file, find_files_of_type
from models.tacotron2.text import text_to_sequence
from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \
load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes
from trainer.injectors.base_injectors import MelSpectrogramInjector
from utils.audio import plot_spectrogram
load_discrete_vocoder_diffuser, wav_to_mel
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
from utils.options import Loader
from utils.util import load_model_from_config
import torch.nn.functional as F
def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
@ -34,9 +32,9 @@ def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length
elif gap > 0:
rand_start = random.randint(0, gap)
rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
mel_clip = MelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{})({'wav': rel_clip.unsqueeze(0)})['mel'].squeeze(0)
mel_clip = wav_to_mel(rel_clip.unsqueeze(0)).squeeze(0)
related_mels.append(mel_clip)
return torch.stack(related_mels, dim=0)
return torch.stack(related_mels, dim=0).unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
@ -48,7 +46,7 @@ if __name__ == '__main__':
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\22000_gpt.pth')
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\23500_gpt.pth')
parser.add_argument('-text', type=str, help='Text to speak.', default="I'm a language model that has learned to speak.")
parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Z:\\clips\\books1\\3042_18_Holden__000000000')
parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3)
@ -62,10 +60,10 @@ if __name__ == '__main__':
print("Loading data..")
text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
conds = load_conditioning_candidates(args.cond_path, args.num_cond).unsqueeze(0).cuda()
conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond)
print("Performing GPT inference..")
codes = gpt.inference(text, conds, num_beams=4) #TODO: check the text length during training and match that during inference.
codes = gpt.inference(text, conds, num_beams=4)
# Delete the GPT TTS model to free up GPU memory
del gpt
@ -77,5 +75,5 @@ if __name__ == '__main__':
diffuser = load_discrete_vocoder_diffuser()
print("Performing vocoding..")
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, conds[:, 0], spectrogram_compression_factor=128, plt_spec=True)
torchaudio.save('gpt_tts_output.wav', wav.squeeze(0), 10025)
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, cond_wav, spectrogram_compression_factor=128, plt_spec=False)
torchaudio.save('gpt_tts_output.wav', wav.squeeze(0).cpu(), 10025)