forked from mrq/tortoise-tts
Fix bug with k>1
This commit is contained in:
parent
cdf44d7506
commit
2888ae0337
|
@ -416,7 +416,8 @@ class TextToSpeech:
|
||||||
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
# inputs. Re-produce those for the top results. This could be made more efficient by storing all of these
|
||||||
# results, but will increase memory usage.
|
# results, but will increase memory usage.
|
||||||
self.autoregressive = self.autoregressive.cuda()
|
self.autoregressive = self.autoregressive.cuda()
|
||||||
best_latents = self.autoregressive(auto_conditioning, text_tokens, torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
best_latents = self.autoregressive(auto_conditioning.repeat(k, 1), text_tokens.repeat(k, 1),
|
||||||
|
torch.tensor([text_tokens.shape[-1]], device=text_tokens.device), best_results,
|
||||||
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
torch.tensor([best_results.shape[-1]*self.autoregressive.mel_length_compression], device=text_tokens.device),
|
||||||
return_latent=True, clip_inputs=False)
|
return_latent=True, clip_inputs=False)
|
||||||
self.autoregressive = self.autoregressive.cpu()
|
self.autoregressive = self.autoregressive.cpu()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user