commit 0dfab973e7 (parent: 75b90be325)

    oops
@@ -167,6 +167,7 @@ class Dataset:
 	prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
 	prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
 	prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
+	prompt_inject_noise: bool = False # adds noise to the input prompt waveform to try and vary things

 	resps_max_samples: int = 1 # number of samples to target for training
 	resps_append_p: float = 1.0 # probability to append another sample to the training target
@@ -176,7 +177,6 @@ class Dataset:
 	reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
 	reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
 	noise_scale: float = 0.25 # scaling noise value
-	noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
 	retokenize_text: bool = False

 	_frames_per_second: int = 0 # allows setting your own hint
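
Net effect of these two config hunks: the noise flag moves up next to the other `prompt_*` fields and is renamed from `noise_inject_in_prom` to `prompt_inject_noise`. A minimal sketch of how the renamed flag pairs with `noise_scale` at sample time — `apply_noise` is a hypothetical helper, the real injection happens inside the dataset's sampling code:

    import torch

    prompt_inject_noise = True   # stands in for cfg.dataset.prompt_inject_noise
    noise_scale = 0.25           # stands in for cfg.dataset.noise_scale

    def apply_noise(waveform: torch.Tensor, scale: float) -> torch.Tensor:
        # hypothetical helper: mix scaled gaussian noise into the prompt waveform
        return waveform + torch.randn_like(waveform) * scale

    proms = torch.randn(16000)   # illustrative one-second mono prompt at 16 kHz
    if prompt_inject_noise:
        proms = apply_noise(proms, noise_scale)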
@@ -1010,7 +1010,8 @@ class Dataset(_Dataset):
 		"""

 		prom_length = 0
-		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
+		duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+		trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0

 		for _ in range(cfg.dataset.prompt_max_samples):
 			if reference is not None:
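
Both trim sites now unpack the range once instead of indexing `prompt_duration_range[0]`/`[1]` inline. A small standalone illustration of the computation (values are made up):

    import random

    frames_per_second = 75                 # stands in for cfg.dataset.frames_per_second
    prompt_duration_range = (3.0, 6.0)     # stands in for cfg.dataset.prompt_duration_range
    trim = True

    duration_lo, duration_hi = prompt_duration_range
    trim_length = int(random.uniform(duration_lo, duration_hi) * frames_per_second) if trim else 0
    # trim_length is a frame count roughly in [225, 450] here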
@@ -1142,7 +1143,7 @@ class Dataset(_Dataset):
 		if task == "tts":
 			proms = self.sample_prompts(spkr_name, reference=path)

-			if cfg.dataset.inject_noise_in_prom:
+			if cfg.dataset.prompt_inject_noise:
 				# sample random noise
 				noise = self.sample_noise()
 				# extend the noise to fill the target audio
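
For reference, "extend the noise to fill the target audio" amounts to tiling the sampled noise clip out to the target length before mixing; a hedged sketch (the tiling helper here is ours, not the repo's):

    import torch

    def tile_to_length(noise: torch.Tensor, target_len: int) -> torch.Tensor:
        # repeat the clip along the time axis, then cut to the exact target length
        reps = (target_len + noise.shape[0] - 1) // noise.shape[0]
        return noise.repeat(reps, *([1] * (noise.dim() - 1)))[:target_len]

    noise = torch.randn(120, 8)            # illustrative short noise clip (frames, levels)
    extended = tile_to_length(noise, 500)  # filled out to the target's 500 frames
    assert extended.shape[0] == 500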
@@ -1156,7 +1157,8 @@ class Dataset(_Dataset):
 		elif task == "tts-c":
 			# trim a piece of the output response
 			if naive:
-				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
+				duration_lo, duration_hi = cfg.dataset.prompt_duration_range
+				trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)

 				proms = resps[:trim_length, :]
 				resps = resps[trim_length:, :]
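
The `tts-c` "naive" path then carves the prompt off the front of the response codes, so the same refactored `trim_length` decides the split point; illustrative shapes:

    import torch

    resps = torch.randint(0, 1024, (600, 8))  # (frames, quantizer levels), illustrative
    trim_length = 300

    proms = resps[:trim_length, :]   # leading frames become the input prompt
    resps = resps[trim_length:, :]   # the remainder stays as the training target
    assert proms.shape[0] + resps.shape[0] == 600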
@@ -87,7 +87,7 @@ def main():

 	parser.add_argument("--random-prompts", action="store_true")
 	parser.add_argument("--lora", action="store_true")
-	parser.add_argument("--comparison", action="store_true")
+	parser.add_argument("--comparison", type=str, default=None)

 	args = parser.parse_args()

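
`--comparison` changes from a boolean switch to a string naming which A/B comparison to run (the branches below accept `lora`, `entropix-sampling`, `ar-temp`, and `input-prompt-length`, with `--lora` kept as shorthand). A minimal argparse sketch of the new shape:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--lora", action="store_true")
    parser.add_argument("--comparison", type=str, default=None)

    args = parser.parse_args(["--comparison", "ar-temp"])
    assert args.comparison == "ar-temp"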
@@ -104,34 +104,47 @@ def main():

 	# comparison kwargs
 	comparison_kwargs = {
-		"enabled": False,
 		"titles": [],
-		"suffix": "_after",
-		"before": {},
-		"after": {}
+		"suffix": "diff",
+		"enabled": {},
+		"disabled": {}
 	}

 	if args.lora:
-		comparison_kwargs["enabled"] = True
-		comparison_kwargs["suffix"] = "_lora"
-		comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
-		comparison_kwargs["before"]["use_lora"] = True
-		comparison_kwargs["after"]["use_lora"] = False
-	# to-do: make this user definable
-	elif args.comparison:
-		comparison_kwargs["enabled"] = True
-		comparison_kwargs["suffix"] = "_entropix"
-		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
-
-		comparison_kwargs["before"]["entropix_sampling"] = True
-		comparison_kwargs["before"]["ar_temp"] = 0.666
-		comparison_kwargs["before"]["top_k"] = 27
-		comparison_kwargs["before"]["top_p"] = 0.9
-		comparison_kwargs["after"]["entropix_sampling"] = False
-		comparison_kwargs["after"]["ar_temp"] = args.ar_temp
-		comparison_kwargs["after"]["top_k"] = args.top_k
-		comparison_kwargs["after"]["top_p"] = args.top_p
+		args.comparison = "lora"

+	# to-do: just make this mappable
+	if args.comparison == "lora":
+		comparison_kwargs["suffix"] = "lora"
+		comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
+
+		comparison_kwargs["disabled"]["use_lora"] = False
+		comparison_kwargs["enabled"]["use_lora"] = True
+	elif args.comparison == "entropix-sampling":
+		comparison_kwargs["suffix"] = "entropix_sampling"
+		comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
+		comparison_kwargs["disabled"]["entropix_sampling"] = False
+		comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
+		comparison_kwargs["disabled"]["top_k"] = args.top_k
+		comparison_kwargs["disabled"]["top_p"] = args.top_p
+		comparison_kwargs["enabled"]["entropix_sampling"] = True
+		comparison_kwargs["enabled"]["ar_temp"] = 0.666
+		comparison_kwargs["enabled"]["top_k"] = 27
+		comparison_kwargs["enabled"]["top_p"] = 0.9
+	elif args.comparison == "ar-temp":
+		comparison_kwargs["suffix"] = "temperature"
+		comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
+
+		comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
+		comparison_kwargs["enabled"]["ar_temp"] = 1.0
+	elif args.comparison == "input-prompt-length":
+		comparison_kwargs["suffix"] = "input_prompt_length"
+		comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
+
+		comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length
+		comparison_kwargs["enabled"]["input-prompt-length"] = 6.0
+	else:
+		raise Exception(f"Unrecognized comparison flag: {args.comparison}")

 	# read html template
 	html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
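
The dict no longer carries its own `enabled` boolean; `args.comparison` gates the feature, and `enabled`/`disabled` (formerly `before`/`after`) become two kwargs overlays the inference loop applies in turn. A condensed sketch of the consumption pattern, mirroring the `kwargs.update(...)` calls further down:

    comparison_kwargs = {
        "suffix": "temperature",
        "titles": ["Temp: 0.90", "Temp: 1.0"],
        "enabled": {"ar_temp": 1.0},    # overrides applied for the comparison pass
        "disabled": {"ar_temp": 0.9},   # baseline values restored afterwards
    }

    kwargs = {"ar_temp": 0.9}
    kwargs.update(comparison_kwargs["enabled"])   # first pass: comparison settings
    # ... run inference to ours_<suffix>.wav ...
    kwargs.update(comparison_kwargs["disabled"])  # second pass: baseline settings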
@@ -204,10 +217,9 @@ def main():
 		if not sample_dir.exists():
 			continue

-		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
-		sources = [ "ms_valle", "yourtts" ]
-
 		samples = []
+		speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
+		sources = [ "ms_valle", "f5" ]

 		# generate demo output
 		for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
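
`samples` is now initialized before the speaker scan, and the external baseline swaps from YourTTS to F5; the speaker discovery itself is a plain pathlib walk, roughly:

    from pathlib import Path

    sample_dir = Path("data/demo/librispeech")   # illustrative layout
    speakers = [d for d in sample_dir.iterdir() if d.is_dir()] if sample_dir.exists() else []
    sources = ["ms_valle", "f5"]                 # external systems to place beside ours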
@@ -217,20 +229,21 @@ def main():
 			reference = dir / "reference.wav"
 			out_path = dir / "out" / "ours.wav"
 			out_path_comparison = dir / "out" / f"ours_{comparison_kwargs["suffix"]}.wav"
+			external_sources = [ dir / "out" / f"{source}.wav" for source in sources ]

-			extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
+			audio_samples = [ prompt, out_path ]
+			if args.comparison:
+				audio_samples += [ out_path_comparison ]
+			audio_samples += [ p for p in external_sources if p.exists() ]

 			if not args.random_prompts or k == "librispeech":
-				extra_sources += [ reference ]
+				audio_samples += [ reference ]

 			samples.append((
 				text,
-				[ prompt, out_path ] + extra_sources,
+				audio_samples,
 			))

-			if args.skip_existing and out_path.exists():
-				continue
-
 			seed = args.seed if args.seed else int(time.time())

 			kwargs = dict(
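
The conditional `extra_sources` expression is replaced by an `audio_samples` list built up in reading order: prompt, our output, the comparison output when one is requested, whichever external baselines actually exist on disk, then the ground-truth reference. A standalone sketch of the ordering (paths are illustrative):

    from pathlib import Path

    prompt = Path("prompt.wav")
    out_path = Path("out/ours.wav")
    out_path_comparison = Path("out/ours_temperature.wav")
    external_sources = [Path("out/ms_valle.wav"), Path("out/f5.wav")]
    comparison = True

    audio_samples = [prompt, out_path]
    if comparison:
        audio_samples += [out_path_comparison]
    audio_samples += [p for p in external_sources if p.exists()]
    audio_samples += [Path("reference.wav")]   # kept unless random prompts are used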
@@ -253,19 +266,20 @@ def main():
 			)

 			def safe_inference( out_path=out_path ):
+				if args.skip_existing and out_path.exists():
+					return
 				try:
 					tts.inference( out_path=out_path, **kwargs )
 				except Exception as e:
 					print(f'Error while processing {out_path}: {e}')

-			if comparison_kwargs["enabled"]:
-				kwargs.update( comparison_kwargs["before"] )
+			if args.comparison:
+				kwargs.update( comparison_kwargs["enabled"] )
 				safe_inference(out_path_comparison)
-				kwargs.update( comparison_kwargs["after"] )
+				kwargs.update( comparison_kwargs["disabled"] )

 			safe_inference()


 		# collate entries into HTML
 		samples = [
 			f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
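
Moving the `skip_existing` check from the loop body into `safe_inference` means the baseline and the comparison output are each skipped independently, instead of one existing file short-circuiting both. A self-contained sketch of the pattern (`run_inference` stands in for the real `tts.inference` call):

    from pathlib import Path

    def run_inference(out_path: Path, **kwargs):
        out_path.write_bytes(b"")   # stand-in for the real synthesis call

    def safe_inference(out_path: Path, skip_existing: bool = True, **kwargs):
        if skip_existing and out_path.exists():
            return                  # this output is done; others may still need work
        try:
            run_inference(out_path=out_path, **kwargs)
        except Exception as e:
            print(f"Error while processing {out_path}: {e}")

    safe_inference(Path("ours.wav"))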
@@ -280,7 +294,7 @@ def main():
 		# write audio into template
 		html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )

-		if comparison_kwargs["enabled"]:
+		if args.comparison:
 			before, after = comparison_kwargs["titles"]
 			if args.random_prompts:
 				html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
@@ -346,14 +346,15 @@ with ui:
 			with gr.Row():
 				layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
 				#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
-				layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
+				layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
 			with gr.Row():
 				layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 				layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
 			with gr.Row():
 				#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
 				layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-				layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+				if cfg.experimental:
+					layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 			layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 		with gr.Tab("Sampler Settings"):
 			with gr.Row():
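
Both the TTS tab here and the STT tab in the next hunk now hide the Entropix checkbox unless experimental features are turned on; the gating pattern in gradio terms, with a plain bool standing in for `cfg.experimental`:

    import gradio as gr

    experimental = False   # stands in for cfg.experimental

    layout = {"inference_tts": {"inputs": {}}}
    with gr.Blocks() as ui:
        inputs = layout["inference_tts"]["inputs"]
        inputs["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature")
        if experimental:
            # only rendered when experimental features are enabled
            inputs["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling")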
@@ -394,7 +395,8 @@ with ui:
 				layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
 			with gr.Row():
 				layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
-				layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
+				if cfg.experimental:
+					layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
 			layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
 		with gr.Tab("Sampler Settings"):
 			with gr.Row():