tweaks to web UI
This commit is contained in:
parent
7f8bd2b936
commit
c74fe2f718
|
@ -465,12 +465,15 @@ class Inference:
|
||||||
recurrent_chunk_size: int = 0
|
recurrent_chunk_size: int = 0
|
||||||
recurrent_forward: bool = False
|
recurrent_forward: bool = False
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def dtype(self):
|
def dtype(self):
|
||||||
if self.weight_dtype == "float16":
|
if self.weight_dtype == "float16":
|
||||||
return torch.float16
|
return torch.float16
|
||||||
if self.weight_dtype == "bfloat16":
|
if self.weight_dtype == "bfloat16":
|
||||||
return torch.bfloat16
|
return torch.bfloat16
|
||||||
|
if self.weight_dtype == "int8":
|
||||||
|
return torch.int8
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
@dataclass()
|
@dataclass()
|
||||||
|
|
|
@ -4,6 +4,7 @@ import argparse
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import functools
|
import functools
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
@ -39,7 +40,7 @@ class timer:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
print(f'Elapsed time: {(perf_counter() - self.start):.3f}s')
|
print(f'[{datetime.now().isoformat()}] Elapsed time: {(perf_counter() - self.start):.3f}s')
|
||||||
|
|
||||||
def init_tts(restart=False):
|
def init_tts(restart=False):
|
||||||
global tts
|
global tts
|
||||||
|
@ -53,9 +54,9 @@ def init_tts(restart=False):
|
||||||
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
|
||||||
parser.add_argument("--ar-ckpt", type=Path, default=None)
|
parser.add_argument("--ar-ckpt", type=Path, default=None)
|
||||||
parser.add_argument("--nar-ckpt", type=Path, default=None)
|
parser.add_argument("--nar-ckpt", type=Path, default=None)
|
||||||
parser.add_argument("--device", type=str, default="cpu")
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
parser.add_argument("--amp", action="store_true")
|
parser.add_argument("--amp", action="store_true")
|
||||||
parser.add_argument("--dtype", type=str, default="float32")
|
parser.add_argument("--dtype", type=str, default="float16")
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
|
||||||
|
@ -174,7 +175,7 @@ with ui:
|
||||||
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
|
||||||
with gr.Column(scale=7):
|
with gr.Column(scale=7):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="This sets a limit of how many steps to perform in the AR pass.")
|
layout["inference"]["inputs"]["max-seconds"] = gr.Slider(value=6, minimum=1, maximum=32, step=0.1, label="Maximum Seconds", info="Limits how many steps to perform in the AR pass.")
|
||||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.2, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||||
|
@ -184,9 +185,9 @@ with ui:
|
||||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
|
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info="Limits the samples that are outside the top P%% of probabilities.")
|
||||||
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
layout["inference"]["inputs"]["top-k"] = gr.Slider(value=0, minimum=0, maximum=1024, step=1, label="Top K", info="Limits the samples to the top K of probabilities.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
layout["inference"]["inputs"]["repetition-penalty"] = gr.Slider(value=1.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty", info="Incurs a penalty to tokens based on how often they appear in a sequence.")
|
||||||
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-4.0, maximum=4.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
layout["inference"]["inputs"]["repetition-penalty-decay"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Repetition Penalty Length Decay", info="Modifies the reptition penalty based on how far back in time the token appeared in the sequence.")
|
||||||
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-40.0, maximum=4.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
layout["inference"]["inputs"]["length-penalty"] = gr.Slider(value=0.0, minimum=-2.0, maximum=2.0, step=0.05, label="Length Penalty", info="(AR only) Modifies the probability of a stop token based on the current length of the sequence.")
|
||||||
|
|
||||||
layout["inference"]["buttons"]["inference"].click(
|
layout["inference"]["buttons"]["inference"].click(
|
||||||
fn=do_inference,
|
fn=do_inference,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user