This commit is contained in:
mrq 2024-09-05 23:21:18 -05:00
parent 413097f5f7
commit 94cf81d38c
6 changed files with 48 additions and 4 deletions

View File

@ -13,6 +13,7 @@ def main():
parser.add_argument("text")
parser.add_argument("references", type=path_list, default=None)
parser.add_argument("--language", type=str, default="en")
parser.add_argument("--task", type=str, default="tts")
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
@ -53,6 +54,7 @@ def main():
text=args.text,
references=args.references,
language=args.language,
task=args.task,
out_path=args.out_path,
input_prompt_length=args.input_prompt_length,
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,

View File

@ -1053,11 +1053,13 @@ class Dataset(_Dataset):
task,
]
# Base TTS (<resp> => <text>)
# Base STT (<resp> => <text>)
elif task == "stt":
# easier to just keep it instead of wrangling around trying to remove it
# it might also help to provide a guidance prompt but who knows right now
proms = self.sample_prompts(spkr_name, ignore=path)
proms = [
task
]
# noise suppression (<text>? <resp+noise> => <resp>)
# speech removal (<text>?<resp+noise> => <noise>)

View File

@ -128,6 +128,7 @@ class TTS():
text,
references,
language="en",
task="tts",
#
max_ar_steps=6 * cfg.dataset.frames_per_second,
max_nar_levels=7,
@ -181,6 +182,40 @@ class TTS():
set_seed(seed)
if task == "stt":
resp = self.encode_audio( references )
lang = self.encode_lang( language )
reps = to_device(reps, device=self.device, dtype=torch.int16)
lang = to_device(lang, device=self.device, dtype=torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
if model_ar is not None:
text_list = model_ar(
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
sampling_top_p=top_p, sampling_top_k=top_k,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
sampling_length_penalty=length_penalty,
sampling_beam_width=beam_width,
sampling_mirostat_tau=mirostat_tau,
sampling_mirostat_eta=mirostat_eta,
sampling_dry_multiplier=dry_multiplier,
sampling_dry_base=dry_base,
sampling_dry_allowed_length=dry_allowed_length,
disable_tqdm=not tqdm,
)
else:
raise Exception("!")
text_list = [ cfg.tokenizer.decode( text ) for text in text_list ]
print( text_list )
return text_list[0]
for line in lines:
if out_path is None:
output_dir = Path("./data/results/")

View File

@ -357,11 +357,14 @@ def example_usage():
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
from ..engines import Engine, Engines
from ..utils import wrapper as ml
from ..utils import setup_logging
import numpy as np
import re
setup_logging()
device = "cuda"
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
"""

View File

@ -120,7 +120,7 @@ class TqdmLoggingHandler(logging.Handler):
self.handleError(record)
@global_leader_only
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
def setup_logging(log_dir: str | Path | None = None, log_level="info"):
handlers = []
#stdout_handler = StreamHandler()

View File

@ -13,7 +13,7 @@ from pathlib import Path
from .inference import TTS, cfg
from .train import train
from .utils import get_devices
from .utils import get_devices, setup_logging
tts = None
@ -338,6 +338,8 @@ with ui:
gr.Markdown(md)
def start( lock=True ):
setup_logging()
ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)