From dee34f096c95ab769b921c2acb349709518a9195 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 16 Dec 2021 23:28:54 -0700 Subject: [PATCH] Add use_gpt_tts script --- codes/scripts/audio/gen/use_gpt_tts.py | 81 ++++++++++++++++++++++++++ codes/utils/util.py | 9 ++- 2 files changed, 87 insertions(+), 3 deletions(-) create mode 100644 codes/scripts/audio/gen/use_gpt_tts.py diff --git a/codes/scripts/audio/gen/use_gpt_tts.py b/codes/scripts/audio/gen/use_gpt_tts.py new file mode 100644 index 00000000..e846f478 --- /dev/null +++ b/codes/scripts/audio/gen/use_gpt_tts.py @@ -0,0 +1,81 @@ +import argparse +import os +import random + +import torch +import torchaudio +import yaml + +from data.audio.unsupervised_audio_dataset import load_audio +from data.util import is_audio_file, find_files_of_type +from models.tacotron2.text import text_to_sequence +from scripts.audio.gen.speech_synthesis_utils import do_spectrogram_diffusion, \ + load_discrete_vocoder_diffuser, wav_to_mel, convert_mel_to_codes +from trainer.injectors.base_injectors import MelSpectrogramInjector +from utils.audio import plot_spectrogram +from utils.options import Loader +from utils.util import load_model_from_config +import torch.nn.functional as F + + +def do_vocoding(dvae, vocoder, diffuser, codes, cond=None, plot_spec=False): + return + + +def load_conditioning_candidates(path, num_conds, sample_rate=22050, cond_length=44100): + candidates = find_files_of_type('img', path, qualifier=is_audio_file)[0] + # Sample with replacement. This can get repeats, but more conveniently handles situations where there are not enough candidates. + related_mels = [] + for k in range(num_conds): + rel_clip = load_audio(candidates[k], sample_rate) + gap = rel_clip.shape[-1] - cond_length + if gap < 0: + rel_clip = F.pad(rel_clip, pad=(0, abs(gap))) + elif gap > 0: + rand_start = random.randint(0, gap) + rel_clip = rel_clip[:, rand_start:rand_start + cond_length] + mel_clip = MelSpectrogramInjector({'in': 'wav', 'out': 'mel'},{})({'wav': rel_clip.unsqueeze(0)})['mel'].squeeze(0) + related_mels.append(mel_clip) + return torch.stack(related_mels, dim=0) + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-opt_diffuse', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae.yml') + parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator') + parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_with_cond_new_dvae_full\\models\\6100_generator_ema.pth') + parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae') + parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts.yml') + parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt') + parser.add_argument('-gpt_tts_model_path', type=str, help='GPT TTS model checkpoint to load.', default='X:\\dlas\\experiments\\train_gpt_tts\\models\\22000_gpt.pth') + parser.add_argument('-text', type=str, help='Text to speak.', default="I'm a language model that has learned to speak.") + parser.add_argument('-cond_path', type=str, help='Folder containing conditioning samples.', default='Z:\\clips\\books1\\3042_18_Holden__000000000') + parser.add_argument('-num_cond', type=int, help='Number of conditioning samples to load.', default=3) + args = parser.parse_args() + + print("Loading GPT TTS..") + with open(args.opt_gpt_tts, mode='r') as f: + gpt_opt = yaml.load(f, Loader=Loader) + gpt_opt['networks'][args.gpt_tts_model_name]['kwargs']['checkpointing'] = False # Required for beam search + gpt = load_model_from_config(preloaded_options=gpt_opt, model_name=args.gpt_tts_model_name, also_load_savepoint=False, load_path=args.gpt_tts_model_path) + + print("Loading data..") + text = torch.IntTensor(text_to_sequence(args.text, ['english_cleaners'])).unsqueeze(0).cuda() + conds = load_conditioning_candidates(args.cond_path, args.num_cond).unsqueeze(0).cuda() + + print("Performing GPT inference..") + codes = gpt.inference(text, conds, num_beams=4) #TODO: check the text length during training and match that during inference. + + # Delete the GPT TTS model to free up GPU memory + del gpt + + print("Loading DVAE..") + dvae = load_model_from_config(args.opt_diffuse, args.dvae_model_name) + print("Loading Diffusion Model..") + diffusion = load_model_from_config(args.opt_diffuse, args.diffusion_model_name, also_load_savepoint=False, load_path=args.diffusion_model_path) + diffuser = load_discrete_vocoder_diffuser() + + print("Performing vocoding..") + wav = do_spectrogram_diffusion(diffusion, dvae, diffuser, codes, conds[:, 0], spectrogram_compression_factor=128, plt_spec=True) + torchaudio.save('gpt_tts_output.wav', wav.squeeze(0), 10025) \ No newline at end of file diff --git a/codes/utils/util.py b/codes/utils/util.py index c21500d3..24cb59a7 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -467,9 +467,12 @@ def clip_grad_norm(parameters: list, parameter_names: list, max_norm: float, nor Loader, Dumper = OrderedYaml() -def load_model_from_config(cfg_file, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None): - with open(cfg_file, mode='r') as f: - opt = yaml.load(f, Loader=Loader) +def load_model_from_config(cfg_file=None, model_name=None, dev='cuda', also_load_savepoint=True, load_path=None, preloaded_options=None): + if preloaded_options is not None: + opt = preloaded_options + else: + with open(cfg_file, mode='r') as f: + opt = yaml.load(f, Loader=Loader) if model_name is None: model_cfg = opt['networks'].values() model_name = next(opt['networks'].keys())