tweaks to web UI
This commit is contained in:
parent
7f8bd2b936
commit
c74fe2f718
|
@ -465,12 +465,15 @@ class Inference:
|
|||
recurrent_chunk_size: int = 0
|
||||
recurrent_forward: bool = False
|
||||
|
||||
|
||||
@cached_property
|
||||
def dtype(self):
|
||||
if self.weight_dtype == "float16":
|
||||
return torch.float16
|
||||
if self.weight_dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
if self.weight_dtype == "int8":
|
||||
return torch.int8
|
||||
return torch.float32
|
||||
|
||||
@dataclass()
|
||||
|
|
|
@ -4,6 +4,7 @@ import argparse
|
|||
import random
|
||||
import tempfile
|
||||
import functools
|
||||
from datetime import datetime
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
@ -39,7 +40,7 @@ class timer:
|
|||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
print(f'Elapsed time: {(perf_counter() - self.start):.3f}s')
|
||||
print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
|
||||
|
||||
def init_tts(restart=False):
|
||||
global tts
|
||||
|
@ -53,9 +54,9 @@ def init_tts(restart=False):
|
|||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||
parser.add_argument("--ar-ckpt", type=Path, default=None)
|
||||
parser.add_argument("--nar-ckpt", type=Path, default=None)
|
||||
parser.add_argument("--device", type=str, default="cpu")
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
parser.add_argument("--amp", action="store_true")
|
||||
parser.add_argument("--dtype", type=str, default="float32")
|
||||
parser.add_argument("--dtype", type=str, default="float16")
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||
|
@ -174,7 +175,7 @@ with ui:
|
|||
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||
with gr.Column(scale=7):
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
|
||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||
|
@ -184,9 +185,9 @@ with ui:
|
|||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
|
||||
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||
with gr.Row():
|
||||
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||
|
||||
layout["inference"]["buttons"]["inference"].click(
|
||||
fn=do_inference,
|
||||
|
|
Loading…
Reference in New Issue
Block a user