modify demo template to say F5 instead of YourTTS, swap LoRA comparison around to make the lora'd the base file, and the no-lora the suffix'd file
This commit is contained in:
parent
02dfc60ac3
commit
1a02cd5bce
|
@ -13,7 +13,7 @@
|
||||||
<th>Prompt</th>
|
<th>Prompt</th>
|
||||||
<th>Our VALL-E</th>
|
<th>Our VALL-E</th>
|
||||||
<th>Original VALL-E</th>
|
<th>Original VALL-E</th>
|
||||||
<th>YourTTS</th>
|
<th>F5-TTS</th>
|
||||||
<th>Ground Truth</th>
|
<th>Ground Truth</th>
|
||||||
</tr>
|
</tr>
|
||||||
</thead>
|
</thead>
|
||||||
|
|
|
@ -33,7 +33,7 @@ from .emb.qnt import decode_to_file
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
def encode(path):
|
def encode(path):
|
||||||
if path is None or path.exists():
|
if path is None or not path.exists():
|
||||||
return ""
|
return ""
|
||||||
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
|
return "data:audio/wav;base64," + base64.b64encode(open(path, "rb").read()).decode('utf-8')
|
||||||
|
|
||||||
|
@ -117,11 +117,11 @@ def main():
|
||||||
|
|
||||||
# to-do: just make this mappable
|
# to-do: just make this mappable
|
||||||
if args.comparison == "lora":
|
if args.comparison == "lora":
|
||||||
comparison_kwargs["suffix"] = "lora"
|
comparison_kwargs["suffix"] = "no_lora"
|
||||||
comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
|
comparison_kwargs["titles"] = ["LoRA", "No LoRA"]
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["use_lora"] = False
|
comparison_kwargs["disabled"]["use_lora"] = True
|
||||||
comparison_kwargs["enabled"]["use_lora"] = True
|
comparison_kwargs["enabled"]["use_lora"] = False
|
||||||
elif args.comparison == "entropix-sampling":
|
elif args.comparison == "entropix-sampling":
|
||||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||||
|
@ -175,7 +175,7 @@ def main():
|
||||||
|
|
||||||
comparison_kwargs["disabled"]["amp"] = current_amp
|
comparison_kwargs["disabled"]["amp"] = current_amp
|
||||||
comparison_kwargs["enabled"]["amp"] = other_amp
|
comparison_kwargs["enabled"]["amp"] = other_amp
|
||||||
else:
|
elif args.comparison:
|
||||||
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
||||||
|
|
||||||
# read html template
|
# read html template
|
||||||
|
@ -221,6 +221,7 @@ def main():
|
||||||
num = args.dataset_samples if args.dataset_samples else length
|
num = args.dataset_samples if args.dataset_samples else length
|
||||||
|
|
||||||
for i in trange( num, desc="Sampling dataset for samples" ):
|
for i in trange( num, desc="Sampling dataset for samples" ):
|
||||||
|
index = i if not cfg.dataset.sample_shuffle else random.randint( i, length )
|
||||||
batch = dataloader.dataset[i]
|
batch = dataloader.dataset[i]
|
||||||
|
|
||||||
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
|
dir = args.demo_dir / args.dataset_dir_name / f'{i}'
|
||||||
|
|
|
@ -302,8 +302,8 @@ class AR_NAR(Base):
|
||||||
# to-do: tune these values, maybe have it factor based on confidence scores or something
|
# to-do: tune these values, maybe have it factor based on confidence scores or something
|
||||||
if low_temperature:
|
if low_temperature:
|
||||||
enabled = n < low_temperature_range
|
enabled = n < low_temperature_range
|
||||||
sampling_repetition_penalty = 1.35 if enabled else original_sampling_repetition_penalty
|
sampling_repetition_penalty = 1.5 if enabled else original_sampling_repetition_penalty
|
||||||
sampling_repetition_penalty_decay = 0.5 if enabled else original_sampling_repetition_penalty_decay
|
sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
|
||||||
sampling_temperature = original_sampling_temperature if enabled else 1.0
|
sampling_temperature = original_sampling_temperature if enabled else 1.0
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
|
|
|
@ -11,7 +11,7 @@ from dataclasses import asdict, dataclass, field
|
||||||
# Simple filter to modify a token's probability if it shows up in the past
|
# Simple filter to modify a token's probability if it shows up in the past
|
||||||
# `one_time` will only apply the penalty once
|
# `one_time` will only apply the penalty once
|
||||||
# `decay` is a factor that will exponentially apply to how far away it is
|
# `decay` is a factor that will exponentially apply to how far away it is
|
||||||
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
|
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=False ):
|
||||||
if factor == 1.0 or previous is None:
|
if factor == 1.0 or previous is None:
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user