From b681fa9d11fcbeb6b41a92846c77cb72b4309020 Mon Sep 17 00:00:00 2001 From: Johan Nordberg Date: Wed, 25 May 2022 11:12:53 +0000 Subject: [PATCH] Skip CLVP if cvvp_amount is 1 Also fixes formatting bug in log message --- tortoise/api.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index 59b1604..f3b729f 100644 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -427,17 +427,21 @@ class TextToSpeech: if self.cvvp is None: print("Computing best candidates using CLVP") else: - print(f"Computing best candidates using CLVP {int((1-cvvp_amount) * 100):02d}% and CVVP {int(cvvp_amount * 100):02d}%") + print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%") for batch in tqdm(samples, disable=not verbose): for i in range(batch.shape[0]): batch[i] = fix_autoregressive_output(batch[i], stop_mel_token) - clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) + if cvvp_amount != 1: + clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False) if auto_conds is not None and cvvp_amount > 0: cvvp_accumulator = 0 for cl in range(auto_conds.shape[1]): cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False) cvvp = cvvp_accumulator / auto_conds.shape[1] - clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount)) + if cvvp_amount == 1: + clip_results.append(cvvp) + else: + clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount)) else: clip_results.append(clvp) clip_results = torch.cat(clip_results, dim=0)