Fix bug with k>1

This commit is contained in:
James Betker 2022-05-02 18:00:22 -06:00
parent cdf44d7506
commit 2888ae0337

View File

@ -416,7 +416,8 @@ class TextToSpeech:
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
# results, but will increase memory usage. # results, but will increase memory usage.
self.autoregressive = self.autoregressive.cuda() self.autoregressive = self.autoregressive.cuda()
best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
return_latent=True, clip_inputs=False) return_latent=True, clip_inputs=False)
self.autoregressive = self.autoregressive.cpu() self.autoregressive = self.autoregressive.cpu()