more fixes

James Betker 2022-04-11 15:18:44 -06:00
parent 6dea7da7a8
commit 3cad1b8114
2 changed files with 13 additions and 11 deletions

@@ -18,7 +18,8 @@ from models.audio.tts.tacotron2 import text_to_sequence
 from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
     convert_mel_to_codes, load_univnet_vocoder, wav_to_univnet_mel
 from trainer.injectors.audio_injectors import denormalize_tacotron_mel
-from utils.util import ceil_multiple, opt_get, load_model_from_config
+from utils.util import ceil_multiple, opt_get, load_model_from_config, pad_or_truncate
 
 class AudioDiffusionFid(evaluator.Evaluator):
     """
@@ -136,11 +137,12 @@ class AudioDiffusionFid(evaluator.Evaluator):
     def tts9_get_autoregressive_codes(self, mel, text):
         mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         text_codes = torch.LongTensor(self.bpe_tokenizer.encode(text)).unsqueeze(0).to(mel.device)
-        cond_inputs = mel.unsqueeze(1)
-        auto_latents = self.local_modules['autoregressive'].forward(cond_inputs, text_codes,
+        cond_inputs = pad_or_truncate(mel, 132300//256).unsqueeze(1)
+        mlc = self.local_modules['autoregressive'].mel_length_compression
+        auto_latents = self.local_modules['autoregressive'](cond_inputs, text_codes,
                                                             torch.tensor([text_codes.shape[-1]], device=mel.device),
                                                             mel_codes,
-                                                            torch.tensor([mel_codes.shape[-1]], device=mel.device),
+                                                            torch.tensor([mel_codes.shape[-1]*mlc], device=mel.device),
                                                             text_first=True, raw_mels=None, return_latent=True,
                                                             clip_inputs=False)
         return auto_latents
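
The second length argument now passes wav-domain lengths: assuming
mel_length_compression is the number of waveform samples covered by one mel
code (an assumption about the autoregressive model, not stated in this diff),
mel_codes.shape[-1] * mlc recovers the approximate audio length in samples.
A sketch of the unit conversion:

    # Assumption: one mel code spans `mlc` waveform samples.
    mlc = 1024          # illustrative value only, not taken from this diff
    num_codes = 300     # e.g. mel_codes.shape[-1]
    approx_wav_len = num_codes * mlc
    print(approx_wav_len)  # 307200 -> length in samples the model expects
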
@@ -283,10 +285,10 @@ if __name__ == '__main__':
     # 34k; conditioning_free: {'frechet_distance': tensor(1.4059, device='cuda:0', dtype=torch.float64), 'intelligibility_loss': tensor(118.3377, device='cuda:0')}
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts_mel_flat_autoregressive_inputs.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\500_generator.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\tts_flat_autoregressive_inputs_r2_initial\\models\\2000_generator.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
                 'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel_autoin'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 561, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 563, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())

@@ -174,7 +174,7 @@ class GptVoiceLatentInjector(Injector):
     def forward(self, state):
         with torch.no_grad():
             mel_inputs = self.to_mel(state[self.input])
-            state_cond = pad_or_truncate(state[self.conditioning_key], 88000)
+            state_cond = pad_or_truncate(state[self.conditioning_key], 132300)
             mel_conds = []
             for k in range(state_cond.shape[1]):
                 mel_conds.append(self.to_mel(state_cond[:, k]))
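
The new conditioning length mirrors the evaluator change above: assuming
22050 Hz audio and a 256-sample mel hop (both assumptions, not stated in
this diff), 132300 samples is exactly six seconds, and 132300//256 is the
matching number of mel frames:

    # Assumptions (not in the diff): 22050 Hz sample rate, hop size 256.
    sample_rate, hop_length = 22050, 256
    cond_samples = 132300
    print(cond_samples / sample_rate)   # 6.0 -> six seconds of conditioning
    print(cond_samples // hop_length)   # 516 -> mel-frame form, i.e. 132300//256
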
@@ -184,9 +184,9 @@ class GptVoiceLatentInjector(Injector):
             self.dvae = self.dvae.to(mel_inputs.device)
             self.gpt = self.gpt.to(mel_inputs.device)
             codes = self.dvae.get_codebook_indices(mel_inputs)
-            latents = self.gpt.forward(mel_conds, state[self.text_input_key],
+            latents = self.gpt(mel_conds, state[self.text_input_key],
                                        state[self.text_lengths_key], codes, state[self.input_lengths_key],
                                        text_first=True, raw_mels=None, return_attentions=False, return_latent=True,
                                        clip_inputs=False)
             assert latents.shape[1] == codes.shape[1]
             return {self.output: latents}
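
End to end, the injector yields one GPT latent per dVAE code, which is what
the assertion guards. A runnable sketch of that invariant with hypothetical
shapes (batch size, code count, and latent width are made up for
illustration):

    import torch

    batch, num_codes, latent_dim = 2, 129, 1024          # hypothetical sizes
    codes = torch.randint(0, 8192, (batch, num_codes))   # stands in for dvae.get_codebook_indices(...)
    latents = torch.randn(batch, num_codes, latent_dim)  # stands in for gpt(..., return_latent=True)
    assert latents.shape[1] == codes.shape[1]            # one latent per mel code
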