forked from mrq/tortoise-tts
Skip CLVP if cvvp_amount is 1
Also fixes formatting bug in log message
This commit is contained in:
parent
a52e3026ba
commit
b681fa9d11
|
@ -427,17 +427,21 @@ class TextToSpeech:
|
||||||
if self.cvvp is None:
|
if self.cvvp is None:
|
||||||
print("Computing best candidates using CLVP")
|
print("Computing best candidates using CLVP")
|
||||||
else:
|
else:
|
||||||
print(f"Computing best candidates using CLVP {int((1-cvvp_amount) * 100):02d}% and CVVP {int(cvvp_amount * 100):02d}%")
|
print(f"Computing best candidates using CLVP {((1-cvvp_amount) * 100):2.0f}% and CVVP {(cvvp_amount * 100):2.0f}%")
|
||||||
for batch in tqdm(samples, disable=not verbose):
|
for batch in tqdm(samples, disable=not verbose):
|
||||||
for i in range(batch.shape[0]):
|
for i in range(batch.shape[0]):
|
||||||
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
batch[i] = fix_autoregressive_output(batch[i], stop_mel_token)
|
||||||
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
|
if cvvp_amount != 1:
|
||||||
|
clvp = self.clvp(text_tokens.repeat(batch.shape[0], 1), batch, return_loss=False)
|
||||||
if auto_conds is not None and cvvp_amount > 0:
|
if auto_conds is not None and cvvp_amount > 0:
|
||||||
cvvp_accumulator = 0
|
cvvp_accumulator = 0
|
||||||
for cl in range(auto_conds.shape[1]):
|
for cl in range(auto_conds.shape[1]):
|
||||||
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
cvvp_accumulator = cvvp_accumulator + self.cvvp(auto_conds[:, cl].repeat(batch.shape[0], 1, 1), batch, return_loss=False)
|
||||||
cvvp = cvvp_accumulator / auto_conds.shape[1]
|
cvvp = cvvp_accumulator / auto_conds.shape[1]
|
||||||
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
|
if cvvp_amount == 1:
|
||||||
|
clip_results.append(cvvp)
|
||||||
|
else:
|
||||||
|
clip_results.append(cvvp * cvvp_amount + clvp * (1-cvvp_amount))
|
||||||
else:
|
else:
|
||||||
clip_results.append(clvp)
|
clip_results.append(clvp)
|
||||||
clip_results = torch.cat(clip_results, dim=0)
|
clip_results = torch.cat(clip_results, dim=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user