modify demo template to say F5 instead of YourTTS, swap LoRA comparison around to make the lora'd the base file, and the no-lora the suffix'd file

This commit is contained in:
mrq 2024-10-21 19:52:02 -05:00
parent 02dfc60ac3
commit 1a02cd5bce
4 changed files with 11 additions and 10 deletions

View File

@ -13,7 +13,7 @@
<th>Prompt</th> <th>Prompt</th>
<th>Our VALL-E</th> <th>Our VALL-E</th>
<th>Original VALL-E</th> <th>Original VALL-E</th>
<th>YourTTS</th> <th>F5-TTS</th>
<th>Ground Truth</th> <th>Ground Truth</th>
</tr> </tr>
</thead> </thead>

View File

@ -33,7 +33,7 @@ from .emb.qnt import decode_to_file
from tqdm import tqdm, trange from tqdm import tqdm, trange
def encode(path): def encode(path):
if path is None or path.exists(): if path is None or not path.exists():
return "" return ""
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8') return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
@ -117,11 +117,11 @@ def main():
# to-do: just make this mappable # to-do: just make this mappable
if args.comparison == "lora": if args.comparison == "lora":
comparison_kwargs["suffix"] = "lora" comparison_kwargs["suffix"] = "no_lora"
comparison_kwargs["titles"] = ["No LoRA", "LoRA"] comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
comparison_kwargs["disabled"]["use_lora"] = False comparison_kwargs["disabled"]["use_lora"] = True
comparison_kwargs["enabled"]["use_lora"] = True comparison_kwargs["enabled"]["use_lora"] = False
elif args.comparison == "entropix-sampling": elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling" comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"] comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
@ -175,7 +175,7 @@ def main():
comparison_kwargs["disabled"]["amp"] = current_amp comparison_kwargs["disabled"]["amp"] = current_amp
comparison_kwargs["enabled"]["amp"] = other_amp comparison_kwargs["enabled"]["amp"] = other_amp
else: elif args.comparison:
raise Exception(f"Unrecognized comparison flag: {args.comparison}") raise Exception(f"Unrecognized comparison flag: {args.comparison}")
# read html template # read html template
@ -221,6 +221,7 @@ def main():
num = args.dataset_samples if args.dataset_samples else length num = args.dataset_samples if args.dataset_samples else length
for i in trange( num, desc="Sampling dataset for samples" ): for i in trange( num, desc="Sampling dataset for samples" ):
index = i if not cfg.dataset.sample_shuffle else random.randint( i, length )
batch = dataloader.dataset[i] batch = dataloader.dataset[i]
dir = args.demo_dir / args.dataset_dir_name / f'{i}' dir = args.demo_dir / args.dataset_dir_name / f'{i}'

View File

@ -302,8 +302,8 @@ class AR_NAR(Base):
# to-do: tune these values, maybe have it factor based on confidence scores or something # to-do: tune these values, maybe have it factor based on confidence scores or something
if low_temperature: if low_temperature:
enabled = n < low_temperature_range enabled = n < low_temperature_range
sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty
sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
sampling_temperature = original_sampling_temperature if enabled else 1.0 sampling_temperature = original_sampling_temperature if enabled else 1.0
inputs = self.inputs( inputs = self.inputs(

View File

@ -11,7 +11,7 @@ from dataclasses import asdict, dataclass, field
# Simple filter to modify a token's probability if it shows up in the past # Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once # `one_time` will only apply the penalty once
# `decay` is a factor that will exponentially apply to how far away it is # `decay` is a factor that will exponentially apply to how far away it is
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ): def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
if factor == 1.0 or previous is None: if factor == 1.0 or previous is None:
return logits return logits