This commit is contained in:
mrq 2024-09-05 23:21:18 -05:00
parent 413097f5f7
commit 94cf81d38c
6 changed files with 48 additions and 4 deletions

View File

@@ -13,6 +13,7 @@ def main():
 	parser.add_argument("text")
 	parser.add_argument("references", type=path_list, default=None)
 	parser.add_argument("--language", type=str, default="en")
+	parser.add_argument("--task", type=str, default="tts")
 	parser.add_argument("--out-path", type=Path, default=None)
 	parser.add_argument("--yaml", type=Path, default=None)
@@ -53,6 +54,7 @@ def main():
 		text=args.text,
 		references=args.references,
 		language=args.language,
+		task=args.task,
 		out_path=args.out_path,
 		input_prompt_length=args.input_prompt_length,
 		max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,

View File

@@ -1053,11 +1053,13 @@ class Dataset(_Dataset):
 				task,
 			]
-		# Base TTS (<resp> => <text>)
+		# Base STT (<resp> => <text>)
 		elif task == "stt":
 			# easier to just keep it instead of wrangling around trying to remove it
 			# it might also help to provide a guidance prompt but who knows right now
-			proms = self.sample_prompts(spkr_name, ignore=path)
+			proms = [
+				task
+			]
 		# noise suppression (<text>? <resp+noise> => <resp>)
 		# speech removal (<text>?<resp+noise> => <noise>)

View File

@@ -128,6 +128,7 @@ class TTS():
 		text,
 		references,
 		language="en",
+		task="tts",
 		#
 		max_ar_steps=6 * cfg.dataset.frames_per_second,
 		max_nar_levels=7,
@@ -181,6 +182,40 @@ class TTS():
 		set_seed(seed)
+		if task == "stt":
+			resp = self.encode_audio( references )
+			lang = self.encode_lang( language )
+
+			reps = to_device(reps, device=self.device, dtype=torch.int16)
+			lang = to_device(lang, device=self.device, dtype=torch.uint8)
+
+			with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
+				if model_ar is not None:
+					text_list = model_ar(
+						text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
+						sampling_temperature=ar_temp,
+						sampling_min_temperature=min_ar_temp,
+						sampling_top_p=top_p, sampling_top_k=top_k,
+						sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+						sampling_length_penalty=length_penalty,
+						sampling_beam_width=beam_width,
+						sampling_mirostat_tau=mirostat_tau,
+						sampling_mirostat_eta=mirostat_eta,
+						sampling_dry_multiplier=dry_multiplier,
+						sampling_dry_base=dry_base,
+						sampling_dry_allowed_length=dry_allowed_length,
+						disable_tqdm=not tqdm,
+					)
+				else:
+					raise Exception("!")
+
+				text_list = [ cfg.tokenizer.decode( text ) for text in text_list ]
+
+			print( text_list )
+			return text_list[0]
 		for line in lines:
 			if out_path is None:
 				output_dir = Path("./data/results/")

View File

@@ -357,12 +357,15 @@ def example_usage():
 	from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
 	from ..engines import Engine, Engines
 	from ..utils import wrapper as ml
+	from ..utils import setup_logging
 	import numpy as np
 	import re
+	setup_logging()
 	device = "cuda"
 	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
 	"""
 	if "mamba" in cfg.model.arch_type:

View File

@@ -120,7 +120,7 @@ class TqdmLoggingHandler(logging.Handler):
 			self.handleError(record)
 @global_leader_only
-def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
+def setup_logging(log_dir: str | Path | None = None, log_level="info"):
 	handlers = []
 	#stdout_handler = StreamHandler()

View File

@@ -13,7 +13,7 @@ from pathlib import Path
 from .inference import TTS, cfg
 from .train import train
-from .utils import get_devices
+from .utils import get_devices, setup_logging
 tts = None
@@ -338,6 +338,8 @@ with ui:
 		gr.Markdown(md)
 def start( lock=True ):
+	setup_logging()
 	ui.queue(max_size=8)
 	ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)