tweaks to web UI

This commit is contained in:
mrq 2023-09-09 22:27:20 -05:00
parent 7f8bd2b936
commit c74fe2f718
2 changed files with 11 additions and 7 deletions

View File

@ -465,12 +465,15 @@ class Inference:
recurrent_chunk_size: int = 0
recurrent_forward: bool = False
@cached_property
def dtype(self):
if self.weight_dtype == "float16":
return torch.float16
if self.weight_dtype == "bfloat16":
return torch.bfloat16
if self.weight_dtype == "int8":
return torch.int8
return torch.float32
@dataclass()

View File

@ -4,6 +4,7 @@ import argparse
import random
import tempfile
import functools
from datetime import datetime
import gradio as gr
@ -39,7 +40,7 @@ class timer:
return self
def __exit__(self, type, value, traceback):
print(f'Elapsed time: {(perf_counter() - self.start):.3f}s')
print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
def init_tts(restart=False):
global tts
@ -53,9 +54,9 @@ def init_tts(restart=False):
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--ar-ckpt", type=Path, default=None)
parser.add_argument("--nar-ckpt", type=Path, default=None)
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default="float32")
parser.add_argument("--dtype", type=str, default="float16")
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
@ -174,7 +175,7 @@ with ui:
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
with gr.Column(scale=7):
with gr.Row():
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
@ -184,9 +185,9 @@ with ui:
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
with gr.Row():
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
layout["inference"]["buttons"]["inference"].click(
fn=do_inference,