forked from mrq/DL-Art-School
Fixes
This commit is contained in:
parent
9b9f7ea61b
commit
937045cb63
|
@ -511,7 +511,7 @@ class UNetModel(nn.Module):
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||||
self.use_raw_y_as_embedding = use_raw_y_as_embedding
|
self.use_raw_y_as_embedding = use_raw_y_as_embedding
|
||||||
assert (self.num_classes is not None) != use_raw_y_as_embedding # These are mutually-exclusive.
|
assert not ((self.num_classes is not None) and use_raw_y_as_embedding) # These are mutually-exclusive.
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
|
|
@ -27,9 +27,7 @@ class GptTtsHf(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.max_mel_tokens = max_mel_tokens
|
self.max_mel_tokens = max_mel_tokens
|
||||||
self.max_symbols_per_phrase = max_symbols_per_phrase
|
self.max_symbols_per_phrase = max_symbols_per_phrase
|
||||||
|
|
||||||
self.model_dim = model_dim
|
self.model_dim = model_dim
|
||||||
self.max_mel_tokens = max_mel_tokens
|
|
||||||
self.max_conditioning_inputs = max_conditioning_inputs
|
self.max_conditioning_inputs = max_conditioning_inputs
|
||||||
self.mel_length_compression = mel_length_compression
|
self.mel_length_compression = mel_length_compression
|
||||||
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
self.conditioning_encoder = AudioMiniEncoder(80, model_dim)
|
||||||
|
@ -112,8 +110,9 @@ class GptTtsHf(nn.Module):
|
||||||
|
|
||||||
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8):
|
def inference(self, text_inputs, cond_inputs, do_sample=False, temperature=1.0, num_beams=8):
|
||||||
if not hasattr(self, 'inference_model'):
|
if not hasattr(self, 'inference_model'):
|
||||||
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.text_pos_embedding, self.final_norm, self.text_head)
|
self.inference_model = GPT2InferenceModel(self.gpt_config, self.gpt, self.mel_pos_embedding, self.final_norm, self.mel_head)
|
||||||
|
|
||||||
|
text_inputs = F.pad(text_inputs, (0, self.max_symbols_per_phrase - text_inputs.shape[1]), value=self.STOP_TEXT_TOKEN)
|
||||||
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.START_TEXT_TOKEN, self.STOP_TEXT_TOKEN)
|
||||||
text_emb = self.text_embedding(text_inputs)
|
text_emb = self.text_embedding(text_inputs)
|
||||||
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
text_emb = text_emb + self.text_pos_embedding(torch.arange(text_emb.shape[1], device=text_inputs.device))
|
||||||
|
@ -124,7 +123,7 @@ class GptTtsHf(nn.Module):
|
||||||
while len(conds) < self.max_conditioning_inputs:
|
while len(conds) < self.max_conditioning_inputs:
|
||||||
conds.append(conds[-1])
|
conds.append(conds[-1])
|
||||||
conds = torch.stack(conds, dim=1)
|
conds = torch.stack(conds, dim=1)
|
||||||
conds = conds + self.conditioning_embedding(torch.arange(conds.shape[1], device=conds.device))
|
conds = conds + self.conditioning_embedding
|
||||||
|
|
||||||
emb = torch.cat([text_emb, conds], dim=1)
|
emb = torch.cat([text_emb, conds], dim=1)
|
||||||
self.inference_model.store_mel_emb(emb)
|
self.inference_model.store_mel_emb(emb)
|
||||||
|
@ -133,8 +132,8 @@ class GptTtsHf(nn.Module):
|
||||||
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
fake_inputs[:,-1] = self.START_MEL_TOKEN
|
||||||
|
|
||||||
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
gen = self.inference_model.generate(fake_inputs, do_sample=do_sample, bos_token_id=self.START_MEL_TOKEN, pad_token_id=self.STOP_MEL_TOKEN, eos_token_id=self.STOP_MEL_TOKEN,
|
||||||
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True)
|
max_length=emb.shape[1]+self.max_mel_tokens, temperature=temperature, num_beams=num_beams, use_cache=True, repetition_penalty=.2)
|
||||||
return gen[:, self.max_mel_frames:]
|
return gen[:, fake_inputs.shape[1]:]
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -10,12 +10,10 @@ from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from data.util import is_audio_file, find_files_of_type
|
from data.util import is_audio_file, find_files_of_type
|
||||||
from models.tacotron2.text import text_to_sequence
|
from models.tacotron2.text import text_to_sequence
|
||||||
from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \
|
from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \
|
||||||
load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes
|
load_discrete_vocoder_diffuser, wav_to_mel
|
||||||
from trainer.injectors.base_injectors import MelSpectrogramInjector
|
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
|
||||||
from utils.audio import plot_spectrogram
|
|
||||||
from utils.options import Loader
|
from utils.options import Loader
|
||||||
from utils.util import load_model_from_config
|
from utils.util import load_model_from_config
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
|
def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False):
|
||||||
|
@ -34,9 +32,9 @@ def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length
|
||||||
elif gap > 0:
|
elif gap > 0:
|
||||||
rand_start = random.randint(0, gap)
|
rand_start = random.randint(0, gap)
|
||||||
rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
|
rel_clip = rel_clip[:, rand_start:rand_start + cond_length]
|
||||||
mel_clip = MelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{})({'wav': rel_clip.unsqueeze(0)})['mel'].squeeze(0)
|
mel_clip = wav_to_mel(rel_clip.unsqueeze(0)).squeeze(0)
|
||||||
related_mels.append(mel_clip)
|
related_mels.append(mel_clip)
|
||||||
return torch.stack(related_mels, dim=0)
|
return torch.stack(related_mels, dim=0).unsqueeze(0).cuda(), rel_clip.unsqueeze(0).cuda()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,7 +46,7 @@ if __name__ == '__main__':
|
||||||
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
|
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
|
||||||
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
|
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml')
|
||||||
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
|
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
|
||||||
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\22000_gpt.pth')
|
parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\23500_gpt.pth')
|
||||||
parser.add_argument('-text', type=str, help='Text to speak.', default="I'm a language model that has learned to speak.")
|
parser.add_argument('-text', type=str, help='Text to speak.', default="I'm a language model that has learned to speak.")
|
||||||
parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Z:\\clips\\books1\\3042_18_Holden__000000000')
|
parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Z:\\clips\\books1\\3042_18_Holden__000000000')
|
||||||
parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3)
|
parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3)
|
||||||
|
@ -62,10 +60,10 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
print("Loading data..")
|
print("Loading data..")
|
||||||
text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
|
text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda()
|
||||||
conds = load_conditioning_candidates(args.cond_path, args.num_cond).unsqueeze(0).cuda()
|
conds, cond_wav = load_conditioning_candidates(args.cond_path, args.num_cond)
|
||||||
|
|
||||||
print("Performing GPT inference..")
|
print("Performing GPT inference..")
|
||||||
codes = gpt.inference(text, conds, num_beams=4) #TODO: check the text length during training and match that during inference.
|
codes = gpt.inference(text, conds, num_beams=4)
|
||||||
|
|
||||||
# Delete the GPT TTS model to free up GPU memory
|
# Delete the GPT TTS model to free up GPU memory
|
||||||
del gpt
|
del gpt
|
||||||
|
@ -77,5 +75,5 @@ if __name__ == '__main__':
|
||||||
diffuser = load_discrete_vocoder_diffuser()
|
diffuser = load_discrete_vocoder_diffuser()
|
||||||
|
|
||||||
print("Performing vocoding..")
|
print("Performing vocoding..")
|
||||||
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, conds[:, 0], spectrogram_compression_factor=128, plt_spec=True)
|
wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, cond_wav, spectrogram_compression_factor=128, plt_spec=False)
|
||||||
torchaudio.save('gpt_tts_output.wav', wav.squeeze(0), 10025)
|
torchaudio.save('gpt_tts_output.wav', wav.squeeze(0).cpu(), 10025)
|
Loading…
Reference in New Issue
Block a user