From 95f679f4ba714c0f2a37d66f4ab8bc33f8b952d8 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 10 Oct 2023 15:30:08 +0000 Subject: [PATCH] possible fix for when candidates >= samples --- tortoise/api.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tortoise/api.py b/tortoise/api.py index 2973bcb..ef86f85 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -815,7 +815,10 @@ class TextToSpeech: clip_results = torch.cat(clip_results, dim=0) samples = torch.cat(samples, dim=0) - best_results = samples[torch.topk(clip_results, k=k).indices] + if k < num_autoregressive_samples: + best_results = samples[torch.topk(clip_results, k=k).indices] + else: + best_results = samples if not self.preloaded_tensors: self.clvp = migrate_to_device( self.clvp, 'cpu' )