more fixes
This commit is contained in: parent 6dea7da7a8, commit 3cad1b8114
@@ -18,7 +18,8 @@ from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
     convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
+from trainer.injectors.audio_injectors import denormalize_tacotron_mel
-from utils.util import ceil_multiple, opt_get, load_model_from_config
+from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate


 class AudioDiffusionFid(evaluator.Evaluator):
     """
@@ -136,11 +137,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
     def tts9_get_autoregressive_codes(self, mel, text):
         mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(mel.device)
-        cond_inputs = mel.unsqueeze(1)
-        auto_latents = self.local_modules['autoregressive'].forward(cond_inputs, text_codes,
+        cond_inputs = pad_or_truncate(mel, 132300//256).unsqueeze(1)
+        mlc = self.local_modules['autoregressive'].mel_length_compression
+        auto_latents = self.local_modules['autoregressive'](cond_inputs, text_codes,
                                                             torch.tensor([text_codes.shape[-1]], device=mel.device),
                                                             mel_codes,
-                                                            torch.tensor([mel_codes.shape[-1]], device=mel.device),
+                                                            torch.tensor([mel_codes.shape[-1]*mlc], device=mel.device),
                                                             text_first=True, raw_mels=None, return_latent=True,
                                                             clip_inputs=False)
         return auto_latents
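Note on the hunk above: pad_or_truncate comes from utils.util (added to the imports in the first hunk), and the conditioning mel is now clamped to 132300//256 frames, the same 132300-sample window the injector hunks below switch to, while the wav-length argument is derived from the code count via the model's mel_length_compression factor. The helper below is only a hedged sketch of the pad-to-length-or-crop behaviour that call is assumed to have, not the repository's implementation:

import torch
import torch.nn.functional as F

def pad_or_truncate_sketch(t: torch.Tensor, length: int) -> torch.Tensor:
    # Hypothetical stand-in for utils.util.pad_or_truncate: force the last
    # dimension of t to exactly `length`, zero-padding short inputs and
    # cropping long ones.
    if t.shape[-1] < length:
        return F.pad(t, (0, length - t.shape[-1]))
    return t[..., :length]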
@@ -283,10 +285,10 @@ if __name__ == '__main__':
     # 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat_autoregressive_inputs.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\500_generator.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\2000_generator.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
                 'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel_autoin'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 561, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 563, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
@@ -174,7 +174,7 @@ class GptVoiceLatentInjector(Injector):
     def forward(self, state):
         with torch.no_grad():
             mel_inputs = self.to_mel(state[self.input])
-            state_cond = pad_or_truncate(state[self.conditioning_key], 88000)
+            state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
             mel_conds = []
             for k in range(state_cond.shape[1]):
                 mel_conds.append(self.to_mel(state_cond[:, k]))
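The conditioning clip here grows from 88000 to 132300 samples, matching the evaluator change above. Assuming the 22050 Hz sample rate and 256-sample mel hop typically used with Tacotron-style mels (an assumption, not stated in this diff), the new constant works out as follows:

sample_rate = 22050   # assumed audio sample rate
hop_length = 256      # assumed mel hop length

clip_samples = 132300
print(clip_samples / sample_rate)   # 6.0 -> six seconds of conditioning audio
print(clip_samples // hop_length)   # 516 -> the 132300//256 mel frames used above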
@@ -184,9 +184,9 @@ class GptVoiceLatentInjector(Injector):
             self.dvae = self.dvae.to(mel_inputs.device)
             self.gpt = self.gpt.to(mel_inputs.device)
             codes = self.dvae.get_codebook_indices(mel_inputs)
-            latents = self.gpt.forward(mel_conds, state[self.text_input_key],
-                                       state[self.text_lengths_key], codes, state[self.input_lengths_key],
-                                       text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
-                                       clip_inputs=False)
+            latents = self.gpt(mel_conds, state[self.text_input_key],
+                               state[self.text_lengths_key], codes, state[self.input_lengths_key],
+                               text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
+                               clip_inputs=False)
             assert latents.shape[1] == codes.shape[1]
             return {self.output: latents}
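The only change in this last hunk is stylistic: self.gpt.forward(...) becomes self.gpt(...), with identical arguments. Calling an nn.Module directly routes through __call__, which also runs any registered hooks around forward, so the direct call is the idiomatic form. A toy illustration (not the repository's GPT class):

import torch
from torch import nn

class Toy(nn.Module):
    def forward(self, x):
        return x * 2

m = Toy()
m.register_forward_hook(lambda mod, inp, out: print('hook ran'))

m(torch.ones(1))          # prints 'hook ran' -- hooks fire via __call__
m.forward(torch.ones(1))  # bypasses hooks; otherwise functionally identical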