This commit is contained in:
mrq 2024-10-18 09:40:06 -05:00
parent 75b90be325
commit 0dfab973e7
4 changed files with 63 additions and 45 deletions

View File

@ -167,6 +167,7 @@ class Dataset:
prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
prompt_inject_noise: bool = False # adds noise to the input prompt waveform to try and vary things
resps_max_samples: int = 1 # number of samples to target for training
resps_append_p: float = 1.0 # probability to append another sample to the training target
@ -176,7 +177,6 @@ class Dataset:
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
noise_scale: float = 0.25 # scaling noise value
noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
retokenize_text: bool = False
_frames_per_second: int = 0 # allows setting your own hint

View File

@ -1010,7 +1010,8 @@ class Dataset(_Dataset):
"""
prom_length = 0
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
for _ in range(cfg.dataset.prompt_max_samples):
if reference is not None:
@ -1142,7 +1143,7 @@ class Dataset(_Dataset):
if task == "tts":
proms = self.sample_prompts(spkr_name, reference=path)
if cfg.dataset.inject_noise_in_prom:
if cfg.dataset.prompt_inject_noise:
# sample random noise
noise = self.sample_noise()
# extend the noise to fill the target audio
@ -1156,7 +1157,8 @@ class Dataset(_Dataset):
elif task == "tts-c":
# trim a piece of the output response
if naive:
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
proms = resps[:trim_length, :]
resps = resps[trim_length:, :]

View File

@ -87,7 +87,7 @@ def main():
parser.add_argument("--random-prompts", action="store_true")
parser.add_argument("--lora", action="store_true")
parser.add_argument("--comparison", action="store_true")
parser.add_argument("--comparison", type=str, default=None)
args = parser.parse_args()
@ -104,34 +104,47 @@ def main():
# comparison kwargs
comparison_kwargs = {
"enabled": False,
"titles": [],
"suffix": "_after",
"before": {},
"after": {}
"suffix": "diff",
"enabled": {},
"disabled": {}
}
if args.lora:
comparison_kwargs["enabled"] = True
comparison_kwargs["suffix"] = "_lora"
args.comparison = "lora"
# to-do: just make this mappable
if args.comparison == "lora":
comparison_kwargs["suffix"] = "lora"
comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
comparison_kwargs["before"]["use_lora"] = True
comparison_kwargs["after"]["use_lora"] = False
# to-do: make this user definable
elif args.comparison:
comparison_kwargs["enabled"] = True
comparison_kwargs["suffix"] = "_entropix"
comparison_kwargs["disabled"]["use_lora"] = False
comparison_kwargs["enabled"]["use_lora"] = True
elif args.comparison == "entropix-sampling":
comparison_kwargs["suffix"] = "entropix_sampling"
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
comparison_kwargs["disabled"]["entropix_sampling"] = False
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
comparison_kwargs["disabled"]["top_k"] = args.top_k
comparison_kwargs["disabled"]["top_p"] = args.top_p
comparison_kwargs["enabled"]["entropix_sampling"] = True
comparison_kwargs["enabled"]["ar_temp"] = 0.666
comparison_kwargs["enabled"]["top_k"] = 27
comparison_kwargs["enabled"]["top_p"] = 0.9
elif args.comparison == "ar-temp":
comparison_kwargs["suffix"] = "temperature"
comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
comparison_kwargs["before"]["entropix_sampling"] = True
comparison_kwargs["before"]["ar_temp"] = 0.666
comparison_kwargs["before"]["top_k"] = 27
comparison_kwargs["before"]["top_p"] = 0.9
comparison_kwargs["after"]["entropix_sampling"] = False
comparison_kwargs["after"]["ar_temp"] = args.ar_temp
comparison_kwargs["after"]["top_k"] = args.top_k
comparison_kwargs["after"]["top_p"] = args.top_p
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
comparison_kwargs["enabled"]["ar_temp"] = 1.0
elif args.comparison == "input-prompt-length":
comparison_kwargs["suffix"] = "input_prompt_length"
comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length
comparison_kwargs["enabled"]["input-prompt-length"] = 6.0
else:
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
# read html template
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
@ -204,10 +217,9 @@ def main():
if not sample_dir.exists():
continue
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
sources = [ "ms_valle", "yourtts" ]
samples = []
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
sources = [ "ms_valle", "f5" ]
# generate demo output
for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
@ -217,20 +229,21 @@ def main():
reference = dir / "reference.wav"
out_path = dir / "out" / "ours.wav"
out_path_comparison = dir / "out" / f"ours_{comparison_kwargs["suffix"]}.wav"
external_sources = [ dir / "out" / f"{source}.wav" for source in sources ]
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
audio_samples = [ prompt, out_path ]
if args.comparison:
audio_samples += [ out_path_comparison ]
audio_samples += [ p for p in external_sources if p.exists() ]
if not args.random_prompts or k == "librispeech":
extra_sources += [ reference ]
audio_samples += [ reference ]
samples.append((
text,
[ prompt, out_path ] + extra_sources,
audio_samples,
))
if args.skip_existing and out_path.exists():
continue
seed = args.seed if args.seed else int(time.time())
kwargs = dict(
@ -253,19 +266,20 @@ def main():
)
def safe_inference( out_path=out_path ):
if args.skip_existing and out_path.exists():
return
try:
tts.inference( out_path=out_path, **kwargs )
except Exception as e:
print(f'Error while processing {out_path}: {e}')
if comparison_kwargs["enabled"]:
kwargs.update( comparison_kwargs["before"] )
if args.comparison:
kwargs.update( comparison_kwargs["enabled"] )
safe_inference(out_path_comparison)
kwargs.update( comparison_kwargs["after"] )
kwargs.update( comparison_kwargs["disabled"] )
safe_inference()
# collate entries into HTML
samples = [
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
@ -280,7 +294,7 @@ def main():
# write audio into template
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
if comparison_kwargs["enabled"]:
if args.comparison:
before, after = comparison_kwargs["titles"]
if args.random_prompts:
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")

View File

@ -346,14 +346,15 @@ with ui:
with gr.Row():
layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row():
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
with gr.Row():
#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
if cfg.experimental:
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():
@ -394,7 +395,8 @@ with ui:
layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
with gr.Row():
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
if cfg.experimental:
layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
with gr.Tab("Sampler Settings"):
with gr.Row():