From 2888ae0337e5c05d9990b7998317510886da16dc Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 2 May 2022 18:00:22 -0600 Subject: [PATCH] Fix bug with k>1 --- tortoise/api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tortoise/api.py b/tortoise/api.py index 5231484..ba05aa1 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -416,7 +416,8 @@ class TextToSpeech: # inputs. Re-produce those for the top results. This could be made more efficient by storing all of these # results, but will increase memory usage. self.autoregressive = self.autoregressive.cuda() - best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, + best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1), + torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results, torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device), return_latent=True, clip_inputs=False) self.autoregressive = self.autoregressive.cpu()