diff --git a/codes/scripts/audio/gen/speech_synthesis_utils.py b/codes/scripts/audio/gen/speech_synthesis_utils.py
index b72bfd2f..513b09aa 100644
--- a/codes/scripts/audio/gen/speech_synthesis_utils.py
+++ b/codes/scripts/audio/gen/speech_synthesis_utils.py
@@ -29,6 +29,17 @@ def load_univnet_vocoder():
     return model


+def load_clvp():
+    from models.clip.text_voice_clip import VoiceCLIP
+    clvp = VoiceCLIP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
+                     text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20,
+                     speech_heads=12, speech_seq_len=430, text_mask_percentage=0, voice_mask_percentage=0,
+                     use_xformers=True)
+    clvp.load_state_dict(torch.load(f"../experiments/clvp_md.pth", map_location=torch.device('cpu')))
+    clvp = clvp.eval()
+    return clvp
+
+
 def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
     """
     Converts an audio clip into a MEL tensor that the vocoder, DVAE and GptTts models use whenever a MEL is called for.
diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py
index e3e4f152..6dfbe823 100644
--- a/codes/trainer/eval/audio_diffusion_fid.py
+++ b/codes/trainer/eval/audio_diffusion_fid.py
@@ -17,7 +17,7 @@ from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.clip.mel_text_clip import MelTextCLIP
 from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
-    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
+    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel, load_clvp
 from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector
 from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate

@@ -47,6 +47,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
                                                        enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
                                                        conditioning_free_k=conditioning_free_k)
+        self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
         self.dev = self.env['device']
         mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
         self.local_modules = {}
@@ -61,9 +62,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
         elif mode == 'ctc_to_mel':
             self.diffusion_fn = self.perform_diffusion_ctc
             self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
+            self.local_modules['clvp'] = load_clvp()
         elif 'tts9_mel' in mode:
             mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
-            self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
             self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
             self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
@@ -173,6 +174,9 @@ class AudioDiffusionFid(evaluator.Evaluator):

     def perform_diffusion_ctc(self, audio, codes, text):
         SAMPLE_RATE = 24000
+        text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(audio.device)
+        clvp_latent = self.local_modules['clvp'].embed_text(text_codes)
+
         real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
         univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=True)
         output_shape = univnet_mel.shape
@@ -180,7 +184,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
                                                'true_normalization': True, 'in': 'in', 'out': 'out'}, {})({'in': audio})['out']
         gen_mel = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': codes.unsqueeze(0),
-                                                                                      'conditioning_input': cond_mel})
+                                                                                      'conditioning_input': cond_mel, 'type': torch.tensor([0], device=codes.device),
+                                                                                      'clvp_input': clvp_latent})
         gen_mel_denorm = denormalize_mel(gen_mel)

         gen_wav = self.local_modules['vocoder'].inference(gen_mel_denorm)

diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index fb0a734d..98a52c19 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -343,13 +343,8 @@ class Mel2vecCodesInjector(Injector):
 class ClvpTextInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
-        from models.clip.text_voice_clip import VoiceCLIP
-        self.clvp = VoiceCLIP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
-                              text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20,
-                              speech_heads=12, speech_seq_len=430, text_mask_percentage=0, voice_mask_percentage=0,
-                              use_xformers=True)
-        self.clvp.load_state_dict(torch.load(f"../experiments/clvp_md.pth", map_location=torch.device('cpu')))
-        self.clvp = self.clvp.eval()
+        from scripts.audio.gen.speech_synthesis_utils import load_clvp
+        self.clvp = load_clvp()
         del self.clvp.speech_transformer  # We will only be using the text transformer.
         self.needs_move = True

diff --git a/codes/utils/music_utils.py b/codes/utils/music_utils.py
index 6b4ca673..79c2f8c9 100644
--- a/codes/utils/music_utils.py
+++ b/codes/utils/music_utils.py
@@ -18,4 +18,4 @@ def get_music_codegen():
                               disable_custom_linear_init=True, do_reconstruction_loss=True)
     model.load_state_dict(torch.load(f"../experiments/m2v_music.pth", map_location=torch.device('cpu')))
     model = model.eval()
-    return model
\ No newline at end of file
+    return model
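
A minimal sketch of how the new CLVP conditioning path is driven, mirroring the calls perform_diffusion_ctc now makes. The checkpoint and tokenizer paths are taken from the diff and assumed to exist locally; the sample sentence is illustrative only.

    # Sketch: build the CLVP text latent that perform_diffusion_ctc feeds to the diffuser.
    import torch

    from data.audio.voice_tokenizer import VoiceBpeTokenizer
    from scripts.audio.gen.speech_synthesis_utils import load_clvp

    tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
    clvp = load_clvp()  # eval-mode VoiceCLIP loaded from ../experiments/clvp_md.pth

    # Tokenize the transcript and embed it with CLVP's text transformer.
    text_codes = torch.LongTensor(tokenizer.encode("the quick brown fox")).unsqueeze(0)  # [1, T]
    with torch.no_grad():
        clvp_latent = clvp.embed_text(text_codes)

The resulting latent is what perform_diffusion_ctc forwards to p_sample_loop via model_kwargs['clvp_input'], alongside the new 'type' tensor; ClvpTextInjector reuses the same load_clvp helper and then drops the speech transformer, since only the text branch is needed at train time.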