adf

commit 31dec016e0 (parent b4269af61b)

Changed files:
  codes/scripts/audio/gen/speech_synthesis_utils.py
  codes/trainer/eval/audio_diffusion_fid.py
  codes/trainer/injectors/audio_injectors.py
  codes/utils/music_utils.py
codes/scripts/audio/gen/speech_synthesis_utils.py

@@ -29,6 +29,17 @@ def load_univnet_vocoder():
     return model
 
 
+def load_clvp():
+    from models.clip.text_voice_clip import VoiceCLIP
+    clvp = VoiceCLIP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
+                     text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20,
+                     speech_heads=12, speech_seq_len=430, text_mask_percentage=0, voice_mask_percentage=0,
+                     use_xformers=True)
+    clvp.load_state_dict(torch.load(f"../experiments/clvp_md.pth", map_location=torch.device('cpu')))
+    clvp = clvp.eval()
+    return clvp
+
+
 def wav_to_mel(wav, mel_norms_file='../experiments/clips_mel_norms.pth'):
     """
     Converts an audio clip into a MEL tensor that the vocoder, DVAE and GptTts models use whenever a MEL is called for.
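For reference, a minimal sketch of how the new `load_clvp()` helper is meant to be consumed, paired with the tokenizer this commit wires into the evaluator below. The checkpoint paths and the `embed_text` call are taken from this diff; the rest is illustrative:

```python
import torch

from data.audio.voice_tokenizer import VoiceBpeTokenizer
from scripts.audio.gen.speech_synthesis_utils import load_clvp

# Paths mirror the ones hard-coded in this commit; adjust to your checkout.
tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
clvp = load_clvp()  # VoiceCLIP in eval mode, weights from clvp_md.pth

with torch.no_grad():
    text_codes = torch.LongTensor(tokenizer.encode("hello world")).unsqueeze(0)
    clvp_latent = clvp.embed_text(text_codes)  # text latent (dim_latent=768)
```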
codes/trainer/eval/audio_diffusion_fid.py

@@ -17,7 +17,7 @@ from data.audio.voice_tokenizer import VoiceBpeTokenizer
 from models.clip.mel_text_clip import MelTextCLIP
 from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
-    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
+    convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel, load_clvp
 from trainer.injectors.audio_injectors import denormalize_mel, TorchMelSpectrogramInjector
 from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
 
@@ -47,6 +47,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         self.diffuser = load_discrete_vocoder_diffuser(desired_diffusion_steps=diffusion_steps, schedule=diffusion_schedule,
                                                        enable_conditioning_free_guidance=conditioning_free_diffusion_enabled,
                                                        conditioning_free_k=conditioning_free_k)
+        self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
         self.dev = self.env['device']
         mode = opt_get(opt_eval, ['diffusion_type'], 'tts')
         self.local_modules = {}
@@ -61,9 +62,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
         elif mode == 'ctc_to_mel':
             self.diffusion_fn = self.perform_diffusion_ctc
             self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
+            self.local_modules['clvp'] = load_clvp()
         elif 'tts9_mel' in mode:
             mel_means, self.mel_max, self.mel_min, mel_stds, mel_vars = torch.load('../experiments/univnet_mel_norms.pth')
-            self.bpe_tokenizer = VoiceBpeTokenizer('../experiments/bpe_lowercase_asr_256.json')
             self.local_modules['dvae'] = load_speech_dvae().cpu()
             self.local_modules['vocoder'] = load_univnet_vocoder().cpu()
             self.diffusion_fn = self.perform_diffusion_tts9_mel_from_codes
@@ -173,6 +174,9 @@ class AudioDiffusionFid(evaluator.Evaluator):
 
     def perform_diffusion_ctc(self, audio, codes, text):
         SAMPLE_RATE = 24000
+        text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(audio.device)
+        clvp_latent = self.local_modules['clvp'].embed_text(text_codes)
+
         real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
         univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=True)
         output_shape = univnet_mel.shape
@@ -180,7 +184,8 @@ class AudioDiffusionFid(evaluator.Evaluator):
                                                'true_normalization': True, 'in': 'in', 'out': 'out'}, {})({'in': audio})['out']
 
         gen_mel = self.diffuser.p_sample_loop(self.model, output_shape, model_kwargs={'codes': codes.unsqueeze(0),
-                                                                                      'conditioning_input': cond_mel})
+                                                                                      'conditioning_input': cond_mel, 'type': torch.tensor([0], device=codes.device),
+                                                                                      'clvp_input': clvp_latent})
         gen_mel_denorm = denormalize_mel(gen_mel)
 
         gen_wav = self.local_modules['vocoder'].inference(gen_mel_denorm)
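The sampler now forwards three conditioning tensors through `model_kwargs`. The CTC diffusion model itself is not part of this diff, so the following is only a hypothetical stand-in illustrating the kwargs contract implied by the `p_sample_loop` call above (names and shapes inferred from this commit):

```python
import torch
import torch.nn as nn

class CtcDiffusionStub(nn.Module):
    """Hypothetical model sketching what the model_kwargs above imply."""
    def forward(self, x, timesteps, codes=None, conditioning_input=None,
                type=None, clvp_input=None):
        # x:                  noisy mel being denoised, shaped like output_shape
        # codes:              (1, S) CTC-aligned code sequence
        # conditioning_input: reference mel from TorchMelSpectrogramInjector
        # type:               (1,) tensor; this evaluator always passes 0
        # clvp_input:         (1, 768) CLVP text latent from embed_text()
        raise NotImplementedError("contract illustration only")
```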
codes/trainer/injectors/audio_injectors.py

@@ -343,13 +343,8 @@ class Mel2vecCodesInjector(Injector):
 class ClvpTextInjector(Injector):
     def __init__(self, opt, env):
         super().__init__(opt, env)
-        from models.clip.text_voice_clip import VoiceCLIP
-        self.clvp = VoiceCLIP(dim_text=768, dim_speech=768, dim_latent=768, num_text_tokens=256, text_enc_depth=20,
-                              text_seq_len=350, text_heads=12, num_speech_tokens=8192, speech_enc_depth=20,
-                              speech_heads=12, speech_seq_len=430, text_mask_percentage=0, voice_mask_percentage=0,
-                              use_xformers=True)
-        self.clvp.load_state_dict(torch.load(f"../experiments/clvp_md.pth", map_location=torch.device('cpu')))
-        self.clvp = self.clvp.eval()
+        from scripts.audio.gen.speech_synthesis_utils import load_clvp
+        self.clvp = load_clvp()
         del self.clvp.speech_transformer  # We will only be using the text transformer.
         self.needs_move = True
 
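This deduplicates the inline `VoiceCLIP` construction into the shared `load_clvp()` helper added above. Since the injector only uses the text side of the model, the speech tower can be dropped after loading to save memory. A minimal sketch of that pattern (the injector's `forward` is not shown in this diff, so the embedding call is an assumption based on how the evaluator uses CLVP):

```python
from scripts.audio.gen.speech_synthesis_utils import load_clvp

clvp = load_clvp()
del clvp.speech_transformer  # text-only use; embed_text() does not need it

# latent = clvp.embed_text(text_codes)  # still valid with the speech tower gone
```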
codes/utils/music_utils.py

@@ -18,4 +18,4 @@ def get_music_codegen():
                              disable_custom_linear_init=True, do_reconstruction_loss=True)
     model.load_state_dict(torch.load(f"../experiments/m2v_music.pth", map_location=torch.device('cpu')))
     model = model.eval()
     return model
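For completeness, a hedged usage sketch of the music code generator touched above (module path inferred from the `utils` entry in this commit's file list):

```python
# Assumes ../experiments/m2v_music.pth exists, as hard-coded in get_music_codegen().
from utils.music_utils import get_music_codegen

model = get_music_codegen()  # m2v music codegen, loaded on CPU in eval mode
```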