This commit is contained in:
James Betker 2022-05-27 12:28:04 -06:00
parent b4269af61b
commit 31dec016e0
4 changed files with 22 additions and 11 deletions

View File

@ -29,6 +29,17 @@ def load_univnet_vocoder():
return model
def load_clvp():
from models.clip.text_voice_clip import VoiceCLIP
clvp = VoiceCLIP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20,
speech_heads=12, speech_seq_len=430, text_mask_percentage=0, voice_mask_percentage=0,
use_xformers=True)
clvp.load_state_dict(torch.load(f"../experiments/clvp_md.pth", map_location=torch.device('cpu')))
clvp = clvp.eval()
return clvp
def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
"""
Converts an audio clip into a MEL tensor that the vocoder, DVAE and GptTts models use whenever a MEL is called for.

View File

@ -17,7 +17,7 @@ from data.audio.voice_tokenizer import VoiceBpeTokenizer
from models.clip.mel_text_clip import MelTextCLIP
from models.audio.tts.tacotron2 import text_to_sequence
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel, load_clvp
from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector
from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
@ -47,6 +47,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
conditioning_free_k=conditioning_free_k)
self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
self.dev = self.env['device']
mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
self.local_modules = {}
@ -61,9 +62,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
elif mode == 'ctc_to_mel':
self.diffusion_fn = self.perform_diffusion_ctc
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
self.local_modules['clvp'] = load_clvp()
elif 'tts9_mel' in mode:
mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
self.local_modules['dvae'] = load_speech_dvae().cpu()
self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
@ -173,6 +174,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
def perform_diffusion_ctc(self, audio, codes, text):
SAMPLE_RATE = 24000
text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(audio.device)
clvp_latent = self.local_modules['clvp'].embed_text(text_codes)
real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=True)
output_shape = univnet_mel.shape
@ -180,7 +184,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
'true_normalization': True, 'in': 'in', 'out': 'out'}, {})({'in': audio})['out']
gen_mel = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': codes.unsqueeze(0),
'conditioning_input': cond_mel})
'conditioning_input': cond_mel, 'type': torch.tensor([0], device=codes.device),
'clvp_input': clvp_latent})
gen_mel_denorm = denormalize_mel(gen_mel)
gen_wav = self.local_modules['vocoder'].inference(gen_mel_denorm)

View File

@ -343,13 +343,8 @@ class Mel2vecCodesInjector(Injector):
class ClvpTextInjector(Injector):
def __init__(self, opt, env):
super().__init__(opt, env)
from models.clip.text_voice_clip import VoiceCLIP
self.clvp = VoiceCLIP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20,
speech_heads=12, speech_seq_len=430, text_mask_percentage=0, voice_mask_percentage=0,
use_xformers=True)
self.clvp.load_state_dict(torch.load(f"../experiments/clvp_md.pth", map_location=torch.device('cpu')))
self.clvp = self.clvp.eval()
from scripts.audio.gen.speech_synthesis_utils import load_clvp
self.clvp = load_clvp()
del self.clvp.speech_transformer # We will only be using the text transformer.
self.needs_move = True

View File

@ -18,4 +18,4 @@ def get_music_codegen():
disable_custom_linear_init=True, do_reconstruction_loss=True)
model.load_state_dict(torch.load(f"../experiments/m2v_music.pth", map_location=torch.device('cpu')))
model = model.eval()
return model
return model