webui for STT (still need to bake the model to handle it better, a few hours so far has it generate what looks like a normal transcription but does not correlate to the audio right now)
This commit is contained in:
parent
d33a906119
commit
4bd9bb39c8
|
@ -186,7 +186,7 @@ class TTS():
|
||||||
resp = self.encode_audio( references )
|
resp = self.encode_audio( references )
|
||||||
lang = self.encode_lang( language )
|
lang = self.encode_lang( language )
|
||||||
|
|
||||||
reps = to_device(reps, device=self.device, dtype=torch.int16)
|
resp = to_device(resp, device=self.device, dtype=torch.int16)
|
||||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||||
|
|
||||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||||
|
|
|
@ -1419,6 +1419,8 @@ class Base(nn.Module):
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# to-do: piece-wise classification, now that there's a head for text
|
||||||
|
# although again, one single monolithic head would be preferable instead......
|
||||||
if self.classifiers is not None:
|
if self.classifiers is not None:
|
||||||
special_tasks = [ "len", "stt" ]
|
special_tasks = [ "len", "stt" ]
|
||||||
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
|
classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
|
||||||
|
|
162
vall_e/webui.py
162
vall_e/webui.py
|
@ -18,7 +18,8 @@ from .utils import get_devices, setup_logging
|
||||||
tts = None
|
tts = None
|
||||||
|
|
||||||
layout = {}
|
layout = {}
|
||||||
layout["inference"] = {}
|
layout["inference_tts"] = {}
|
||||||
|
layout["inference_stt"] = {}
|
||||||
layout["training"] = {}
|
layout["training"] = {}
|
||||||
layout["settings"] = {}
|
layout["settings"] = {}
|
||||||
|
|
||||||
|
@ -108,8 +109,8 @@ def init_tts(yaml=None, restart=False, device="cuda", dtype="auto", attention="a
|
||||||
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
|
tts = TTS( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp, attention=args.attention )
|
||||||
return tts
|
return tts
|
||||||
|
|
||||||
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["inference_tts"]["inputs"].keys())
|
||||||
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
if not cfg.yaml_path:
|
if not cfg.yaml_path:
|
||||||
raise Exception("No YAML loaded.")
|
raise Exception("No YAML loaded.")
|
||||||
|
|
||||||
|
@ -123,6 +124,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||||
# I'm very sure I can procedurally generate this list
|
# I'm very sure I can procedurally generate this list
|
||||||
parser.add_argument("--text", type=str, default=kwargs["text"])
|
parser.add_argument("--text", type=str, default=kwargs["text"])
|
||||||
|
parser.add_argument("--task", type=str, default=kwargs["task"])
|
||||||
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||||
parser.add_argument("--language", type=str, default="en")
|
parser.add_argument("--language", type=str, default="en")
|
||||||
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
|
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
|
||||||
|
@ -159,6 +161,7 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
wav, sr = tts.inference(
|
wav, sr = tts.inference(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
language=args.language,
|
language=args.language,
|
||||||
|
task=args.task,
|
||||||
references=[args.references.split(";")] if args.references is not None else [],
|
references=[args.references.split(";")] if args.references is not None else [],
|
||||||
out_path=tmp.name,
|
out_path=tmp.name,
|
||||||
max_ar_steps=args.max_ar_steps,
|
max_ar_steps=args.max_ar_steps,
|
||||||
|
@ -183,6 +186,67 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
wav = wav.squeeze(0).cpu().numpy()
|
wav = wav.squeeze(0).cpu().numpy()
|
||||||
return (sr, wav)
|
return (sr, wav)
|
||||||
|
|
||||||
|
@gradio_wrapper(inputs=layout["inference_stt"]["inputs"].keys())
|
||||||
|
def do_inference_stt( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
|
if not cfg.yaml_path:
|
||||||
|
raise Exception("No YAML loaded.")
|
||||||
|
|
||||||
|
if kwargs.pop("dynamic-sampling", False):
|
||||||
|
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
||||||
|
else:
|
||||||
|
kwargs['min-ar-temp'] = -1
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||||
|
# I'm very sure I can procedurally generate this list
|
||||||
|
parser.add_argument("--references", type=str, default=kwargs["reference"])
|
||||||
|
parser.add_argument("--language", type=str, default="en")
|
||||||
|
parser.add_argument("--max-ar-steps", type=int, default=int(cfg.dataset.frames_per_second))
|
||||||
|
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
|
||||||
|
parser.add_argument("--min-ar-temp", type=float, default=kwargs["min-ar-temp"])
|
||||||
|
parser.add_argument("--top-p", type=float, default=kwargs["top-p"])
|
||||||
|
parser.add_argument("--top-k", type=int, default=kwargs["top-k"])
|
||||||
|
parser.add_argument("--repetition-penalty", type=float, default=kwargs["repetition-penalty"])
|
||||||
|
parser.add_argument("--repetition-penalty-decay", type=float, default=kwargs["repetition-penalty-decay"])
|
||||||
|
parser.add_argument("--length-penalty", type=float, default=kwargs["length-penalty"])
|
||||||
|
parser.add_argument("--beam-width", type=int, default=kwargs["beam-width"])
|
||||||
|
parser.add_argument("--mirostat-tau", type=float, default=kwargs["mirostat-tau"])
|
||||||
|
parser.add_argument("--mirostat-eta", type=float, default=kwargs["mirostat-eta"])
|
||||||
|
parser.add_argument("--dry-multiplier", type=float, default=kwargs["dry-multiplier"])
|
||||||
|
parser.add_argument("--dry-base", type=float, default=kwargs["dry-base"])
|
||||||
|
parser.add_argument("--dry-allowed-length", type=int, default=kwargs["dry-allowed-length"])
|
||||||
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not args.references:
|
||||||
|
raise Exception("No reference audio provided.")
|
||||||
|
"""
|
||||||
|
|
||||||
|
tts = init_tts()
|
||||||
|
|
||||||
|
gr.Info("Inferencing...")
|
||||||
|
with timer("Inferenced in") as t:
|
||||||
|
text = tts.inference(
|
||||||
|
text="",
|
||||||
|
language=args.language,
|
||||||
|
task="stt",
|
||||||
|
references=[args.references.split(";")] if args.references is not None else [],
|
||||||
|
max_ar_steps=args.max_ar_steps,
|
||||||
|
ar_temp=args.ar_temp,
|
||||||
|
min_ar_temp=args.min_ar_temp,
|
||||||
|
top_p=args.top_p,
|
||||||
|
top_k=args.top_k,
|
||||||
|
repetition_penalty=args.repetition_penalty,
|
||||||
|
repetition_penalty_decay=args.repetition_penalty_decay,
|
||||||
|
length_penalty=args.length_penalty,
|
||||||
|
mirostat_tau=args.mirostat_tau,
|
||||||
|
mirostat_eta=args.mirostat_eta,
|
||||||
|
dry_multiplier=args.dry_multiplier,
|
||||||
|
dry_base=args.dry_base,
|
||||||
|
dry_allowed_length=args.dry_allowed_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
||||||
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
|
@ -255,49 +319,87 @@ if args.listen_port is not None:
|
||||||
# setup gradio
|
# setup gradio
|
||||||
ui = gr.Blocks()
|
ui = gr.Blocks()
|
||||||
with ui:
|
with ui:
|
||||||
with gr.Tab("Inference"):
|
with gr.Tab("Inference (TTS)"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=8):
|
with gr.Column(scale=8):
|
||||||
layout["inference"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
layout["inference_tts"]["inputs"]["text"] = gr.Textbox(lines=5, value=get_random_prompt, label="Input Prompt")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
layout["inference"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
|
layout["inference_tts"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
|
||||||
# layout["inference"]["stop"] = gr.Button(value="Stop")
|
# layout["inference_tts"]["stop"] = gr.Button(value="Stop")
|
||||||
layout["inference"]["outputs"]["output"] = gr.Audio(label="Output")
|
layout["inference_tts"]["outputs"]["output"] = gr.Audio(label="Output")
|
||||||
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
layout["inference_tts"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||||
with gr.Column(scale=7):
|
with gr.Column(scale=7):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
layout["inference_tts"]["inputs"]["max-seconds"] = gr.Slider(value=12, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||||
#layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
#layout["inference_tts"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
layout["inference_tts"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
|
layout["inference_tts"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
|
||||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
layout["inference_tts"]["inputs"]["nar-temp"] = gr.Slider(value=0.01, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR. (0 to greedy sample)")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
layout["inference_tts"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
layout["inference_tts"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||||
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
layout["inference_tts"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||||
layout["inference"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
layout["inference_tts"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
layout["inference_tts"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||||
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
layout["inference_tts"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||||
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
layout["inference_tts"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
layout["inference_tts"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
||||||
layout["inference"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
layout["inference_tts"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
|
layout["inference_tts"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
|
||||||
layout["inference"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
layout["inference_tts"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
||||||
layout["inference"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
|
layout["inference_tts"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
|
||||||
|
|
||||||
layout["inference"]["buttons"]["inference"].click(
|
layout["inference_tts"]["buttons"]["inference"].click(
|
||||||
fn=do_inference,
|
fn=do_inference_tts,
|
||||||
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
|
inputs=[ x for x in layout["inference_tts"]["inputs"].values() if x is not None],
|
||||||
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
|
outputs=[ x for x in layout["inference_tts"]["outputs"].values() if x is not None]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Tab("Inference (STT)"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=8):
|
||||||
|
layout["inference_stt"]["outputs"]["ouput"] = gr.Textbox(lines=1, label="Input Prompt")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
layout["inference_stt"]["inputs"]["reference"] = gr.Audio(label="Audio Input", sources=["upload"], type="filepath") #, info="Reference audio for TTS")
|
||||||
|
# layout["inference_stt"]["stop"] = gr.Button(value="Stop")
|
||||||
|
layout["inference_stt"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||||
|
with gr.Column(scale=7):
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference_stt"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR. (0 to greedy sample)")
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference_stt"]["inputs"]["dynamic-sampling"] = gr.Checkbox(label="Dynamic Temperature", info="Dynamically adjusts the temperature based on the highest confident predicted token per sampling step.")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference_stt"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||||
|
layout["inference_stt"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||||
|
layout["inference_stt"]["inputs"]["beam-width"] = gr.Slider(value=0, minimum=0, maximum=32, step=1, label="Beam Width", info="Number of branches to search through for beam search sampling.")
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference_stt"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||||
|
layout["inference_stt"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||||
|
layout["inference_stt"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference_stt"]["inputs"]["mirostat-tau"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="Mirostat τ (Tau)", info="The \"surprise\" value when performing mirostat sampling. 0 to disable.")
|
||||||
|
layout["inference_stt"]["inputs"]["mirostat-eta"] = gr.Slider(value=0.0, minimum=0.0, maximum=2.0, step=0.05, label="Mirostat η (Eta)", info="The \"learning rate\" during mirostat sampling applied to the maximum surprise.")
|
||||||
|
with gr.Row():
|
||||||
|
layout["inference_stt"]["inputs"]["dry-multiplier"] = gr.Slider(value=0.0, minimum=0.0, maximum=8.0, step=0.05, label="DRY Multiplier", info="The multiplying factor for the DRY score penalty (0 to disable DRY sampling).")
|
||||||
|
layout["inference_stt"]["inputs"]["dry-base"] = gr.Slider(value=1.75, minimum=0.0, maximum=8.0, step=0.05, label="DRY Base", info="The base of the exponent in the DRY score penalty")
|
||||||
|
layout["inference_stt"]["inputs"]["dry-allowed-length"] = gr.Slider(value=2, minimum=0, maximum=75, step=1, label="Allowed Length", info="The maximimum length a token can be to perform DRY penalty with.")
|
||||||
|
|
||||||
|
layout["inference_stt"]["buttons"]["inference"].click(
|
||||||
|
fn=do_inference_stt,
|
||||||
|
inputs=[ x for x in layout["inference_stt"]["inputs"].values() if x is not None],
|
||||||
|
outputs=[ x for x in layout["inference_stt"]["outputs"].values() if x is not None]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with gr.Tab("Training"):
|
with gr.Tab("Training"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user