From c74fe2f718ce988c8aebde4b965f8b05334ceaf0 Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 9 Sep 2023 22:27:20 -0500 Subject: [PATCH] tweaks to web UI --- vall_e/config.py | 3 +++ vall_e/webui.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 3f8c11a..d026e49 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -465,12 +465,15 @@ class Inference: recurrent_chunk_size: int = 0 recurrent_forward: bool = False + @cached_property def dtype(self): if self.weight_dtype == "float16": return torch.float16 if self.weight_dtype == "bfloat16": return torch.bfloat16 + if self.weight_dtype == "int8": + return torch.int8 return torch.float32 @dataclass() diff --git a/vall_e/webui.py b/vall_e/webui.py index a7cff16..cd344b7 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -4,6 +4,7 @@ import argparse import random import tempfile import functools +from datetime import datetime import gradio as gr @@ -39,7 +40,7 @@ class timer: return self def __exit__(self, type, value, traceback): - print(f'Elapsed time: {(perf_counter() - self.start):.3f}s') + print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s') def init_tts(restart=False): global tts @@ -53,9 +54,9 @@ def init_tts(restart=False): parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--ar-ckpt", type=Path, default=None) parser.add_argument("--nar-ckpt", type=Path, default=None) - parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--amp", action="store_true") - parser.add_argument("--dtype", type=str, default="float32") + parser.add_argument("--dtype", type=str, default="float16") args, unknown = parser.parse_known_args() tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) @@ -174,7 +175,7 @@ with ui: layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference") with gr.Column(scale=7): with gr.Row(): - layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.") + layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.") layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.") with gr.Row(): layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.") @@ -184,9 +185,9 @@ with ui: layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.") layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.") with gr.Row(): - layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") - layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.") - layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.") + layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.") + layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.") + layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.") layout["inference"]["buttons"]["inference"].click( fn=do_inference,