oops
This commit is contained in:
parent
75b90be325
commit
0dfab973e7
|
@ -167,6 +167,7 @@ class Dataset:
|
|||
prompt_similar_p: float = 0.75 # odds of sampling for a similar prompt instead of a random prompt
|
||||
prompt_similar_top_k: int = 1 # top-k similar candidates to sample from
|
||||
prompt_similar_top_k_offset: int = 0 # offset from the top-k to sample from
|
||||
prompt_inject_noise: bool = False # adds noise to the input prompt waveform to try and vary things
|
||||
|
||||
resps_max_samples: int = 1 # number of samples to target for training
|
||||
resps_append_p: float = 1.0 # probability to append another sample to the training target
|
||||
|
@ -176,7 +177,6 @@ class Dataset:
|
|||
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
|
||||
reencode_device: str = "cpu" # "cpu" is slower but saves memory, cuda throws [rank0]: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
|
||||
noise_scale: float = 0.25 # scaling noise value
|
||||
noise_inject_in_prom: bool = False # adds noise to the input prompt waveform to try and vary things
|
||||
retokenize_text: bool = False
|
||||
|
||||
_frames_per_second: int = 0 # allows setting your own hint
|
||||
|
|
|
@ -1010,7 +1010,8 @@ class Dataset(_Dataset):
|
|||
"""
|
||||
|
||||
prom_length = 0
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second) if trim else 0
|
||||
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
|
||||
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second) if trim else 0
|
||||
|
||||
for _ in range(cfg.dataset.prompt_max_samples):
|
||||
if reference is not None:
|
||||
|
@ -1142,7 +1143,7 @@ class Dataset(_Dataset):
|
|||
if task == "tts":
|
||||
proms = self.sample_prompts(spkr_name, reference=path)
|
||||
|
||||
if cfg.dataset.inject_noise_in_prom:
|
||||
if cfg.dataset.prompt_inject_noise:
|
||||
# sample random noise
|
||||
noise = self.sample_noise()
|
||||
# extend the noise to fill the target audio
|
||||
|
@ -1156,7 +1157,8 @@ class Dataset(_Dataset):
|
|||
elif task == "tts-c":
|
||||
# trim a piece of the output response
|
||||
if naive:
|
||||
trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
|
||||
duration_lo, duration_hi = cfg.dataset.prompt_duration_range
|
||||
trim_length = int(random.uniform(duration_lo, duration_hi) * cfg.dataset.frames_per_second)
|
||||
|
||||
proms = resps[:trim_length, :]
|
||||
resps = resps[trim_length:, :]
|
||||
|
|
|
@ -87,7 +87,7 @@ def main():
|
|||
|
||||
parser.add_argument("--random-prompts", action="store_true")
|
||||
parser.add_argument("--lora", action="store_true")
|
||||
parser.add_argument("--comparison", action="store_true")
|
||||
parser.add_argument("--comparison", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -104,34 +104,47 @@ def main():
|
|||
|
||||
# comparison kwargs
|
||||
comparison_kwargs = {
|
||||
"enabled": False,
|
||||
"titles": [],
|
||||
"suffix": "_after",
|
||||
"before": {},
|
||||
"after": {}
|
||||
"suffix": "diff",
|
||||
"enabled": {},
|
||||
"disabled": {}
|
||||
}
|
||||
|
||||
if args.lora:
|
||||
comparison_kwargs["enabled"] = True
|
||||
comparison_kwargs["suffix"] = "_lora"
|
||||
comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
|
||||
comparison_kwargs["before"]["use_lora"] = True
|
||||
comparison_kwargs["after"]["use_lora"] = False
|
||||
# to-do: make this user definable
|
||||
elif args.comparison:
|
||||
comparison_kwargs["enabled"] = True
|
||||
comparison_kwargs["suffix"] = "_entropix"
|
||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||
|
||||
comparison_kwargs["before"]["entropix_sampling"] = True
|
||||
comparison_kwargs["before"]["ar_temp"] = 0.666
|
||||
comparison_kwargs["before"]["top_k"] = 27
|
||||
comparison_kwargs["before"]["top_p"] = 0.9
|
||||
comparison_kwargs["after"]["entropix_sampling"] = False
|
||||
comparison_kwargs["after"]["ar_temp"] = args.ar_temp
|
||||
comparison_kwargs["after"]["top_k"] = args.top_k
|
||||
comparison_kwargs["after"]["top_p"] = args.top_p
|
||||
args.comparison = "lora"
|
||||
|
||||
# to-do: just make this mappable
|
||||
if args.comparison == "lora":
|
||||
comparison_kwargs["suffix"] = "lora"
|
||||
comparison_kwargs["titles"] = ["No LoRA", "LoRA"]
|
||||
|
||||
comparison_kwargs["disabled"]["use_lora"] = False
|
||||
comparison_kwargs["enabled"]["use_lora"] = True
|
||||
elif args.comparison == "entropix-sampling":
|
||||
comparison_kwargs["suffix"] = "entropix_sampling"
|
||||
comparison_kwargs["titles"] = ["Without Entropix", "With Entropix"]
|
||||
comparison_kwargs["disabled"]["entropix_sampling"] = False
|
||||
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
||||
comparison_kwargs["disabled"]["top_k"] = args.top_k
|
||||
comparison_kwargs["disabled"]["top_p"] = args.top_p
|
||||
comparison_kwargs["enabled"]["entropix_sampling"] = True
|
||||
comparison_kwargs["enabled"]["ar_temp"] = 0.666
|
||||
comparison_kwargs["enabled"]["top_k"] = 27
|
||||
comparison_kwargs["enabled"]["top_p"] = 0.9
|
||||
elif args.comparison == "ar-temp":
|
||||
comparison_kwargs["suffix"] = "temperature"
|
||||
comparison_kwargs["titles"] = [f"Temp: {args.ar_temp:.2f}", "Temp: 1.0"]
|
||||
|
||||
comparison_kwargs["disabled"]["ar_temp"] = args.ar_temp
|
||||
comparison_kwargs["enabled"]["ar_temp"] = 1.0
|
||||
elif args.comparison == "input-prompt-length":
|
||||
comparison_kwargs["suffix"] = "input_prompt_length"
|
||||
comparison_kwargs["titles"] = [f"Prompt Length: {args.input_prompt_length:.2f}s", "Prompt Length: 6.0s"]
|
||||
|
||||
comparison_kwargs["disabled"]["input-prompt-length"] = args.input_prompt_length
|
||||
comparison_kwargs["enabled"]["input-prompt-length"] = 6.0
|
||||
else:
|
||||
raise Exception(f"Unrecognized comparison flag: {args.comparison}")
|
||||
|
||||
# read html template
|
||||
html = open(args.demo_dir / "index.template.html", "r", encoding="utf-8").read()
|
||||
|
@ -204,10 +217,9 @@ def main():
|
|||
if not sample_dir.exists():
|
||||
continue
|
||||
|
||||
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
||||
sources = [ "ms_valle", "yourtts" ]
|
||||
|
||||
samples = []
|
||||
speakers = [ dir for dir in sample_dir.iterdir() if dir.is_dir() ]
|
||||
sources = [ "ms_valle", "f5" ]
|
||||
|
||||
# generate demo output
|
||||
for dir in tqdm(speakers, desc=f"Generating demo for {k}"):
|
||||
|
@ -217,20 +229,21 @@ def main():
|
|||
reference = dir / "reference.wav"
|
||||
out_path = dir / "out" / "ours.wav"
|
||||
out_path_comparison = dir / "out" / f"ours_{comparison_kwargs["suffix"]}.wav"
|
||||
external_sources = [ dir / "out" / f"{source}.wav" for source in sources ]
|
||||
|
||||
extra_sources = [ dir / "out" / f"{source}.wav" for source in sources ] if k == "librispeech" else ([ out_path_comparison ] if comparison_kwargs["enabled"] else [])
|
||||
audio_samples = [ prompt, out_path ]
|
||||
if args.comparison:
|
||||
audio_samples += [ out_path_comparison ]
|
||||
audio_samples += [ p for p in external_sources if p.exists() ]
|
||||
|
||||
if not args.random_prompts or k == "librispeech":
|
||||
extra_sources += [ reference ]
|
||||
audio_samples += [ reference ]
|
||||
|
||||
samples.append((
|
||||
text,
|
||||
[ prompt, out_path ] + extra_sources,
|
||||
audio_samples,
|
||||
))
|
||||
|
||||
if args.skip_existing and out_path.exists():
|
||||
continue
|
||||
|
||||
seed = args.seed if args.seed else int(time.time())
|
||||
|
||||
kwargs = dict(
|
||||
|
@ -253,19 +266,20 @@ def main():
|
|||
)
|
||||
|
||||
def safe_inference( out_path=out_path ):
|
||||
if args.skip_existing and out_path.exists():
|
||||
return
|
||||
try:
|
||||
tts.inference( out_path=out_path, **kwargs )
|
||||
except Exception as e:
|
||||
print(f'Error while processing {out_path}: {e}')
|
||||
|
||||
if comparison_kwargs["enabled"]:
|
||||
kwargs.update( comparison_kwargs["before"] )
|
||||
if args.comparison:
|
||||
kwargs.update( comparison_kwargs["enabled"] )
|
||||
safe_inference(out_path_comparison)
|
||||
kwargs.update( comparison_kwargs["after"] )
|
||||
kwargs.update( comparison_kwargs["disabled"] )
|
||||
|
||||
safe_inference()
|
||||
|
||||
|
||||
# collate entries into HTML
|
||||
samples = [
|
||||
f'\n\t\t\t<tr>\n\t\t\t\t<td>{text}</td>'+
|
||||
|
@ -280,7 +294,7 @@ def main():
|
|||
# write audio into template
|
||||
html = html.replace("${"+k.upper()+"_SAMPLES}", "\n".join( samples ) )
|
||||
|
||||
if comparison_kwargs["enabled"]:
|
||||
if args.comparison:
|
||||
before, after = comparison_kwargs["titles"]
|
||||
if args.random_prompts:
|
||||
html = html.replace("<th>Our VALL-E</th>\n\t\t\t\t\t<th>Ground Truth</th>", f"<th>Our VALL-E ({before})</th>\n\t\t\t\t\t<th>Our VALL-E ({after})</th>")
|
||||
|
|
|
@ -346,14 +346,15 @@ with ui:
|
|||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||
#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=5.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Repeat/Trim Length", info="Repeats and trims the input prompt down to X seconds. Set 0 to disable.")
|
||||
with gr.Row():
|
||||
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.9, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
|
||||
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||
with gr.Row():
|
||||
#layout["inference_tts"]["inputs"]["input-prompt-prefix"] = gr.Checkbox(label="Input Prompt as Prefix", info="Treats the input prompt clip as the prefix of the generated sequence.")
|
||||
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||
if cfg.experimental:
|
||||
layout["inference_tts"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||
layout["inference_tts"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
|
||||
with gr.Tab("Sampler Settings"):
|
||||
with gr.Row():
|
||||
|
@ -394,7 +395,8 @@ with ui:
|
|||
layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.0, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
|
||||
with gr.Row():
|
||||
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||
layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||
if cfg.experimental:
|
||||
layout["inference_stt"]["inputs"]["entropix-sampling"] = gr.Checkbox(label="Entropix Sampling", info="Dynamically samples based on entropy/varentropy values from the logits / attention scores.")
|
||||
layout["inference_stt"]["inputs"]["language"] = gr.Dropdown(choices=get_languages(), label="Language", value="en")
|
||||
with gr.Tab("Sampler Settings"):
|
||||
with gr.Row():
|
||||
|
|
Loading…
Reference in New Issue
Block a user