tweak
This commit is contained in:
parent
413097f5f7
commit
94cf81d38c
|
@ -13,6 +13,7 @@ def main():
|
|||
parser.add_argument("text")
|
||||
parser.add_argument("references", type=path_list, default=None)
|
||||
parser.add_argument("--language", type=str, default="en")
|
||||
parser.add_argument("--task", type=str, default="tts")
|
||||
parser.add_argument("--out-path", type=Path, default=None)
|
||||
|
||||
parser.add_argument("--yaml", type=Path, default=None)
|
||||
|
@ -53,6 +54,7 @@ def main():
|
|||
text=args.text,
|
||||
references=args.references,
|
||||
language=args.language,
|
||||
task=args.task,
|
||||
out_path=args.out_path,
|
||||
input_prompt_length=args.input_prompt_length,
|
||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||
|
|
|
@ -1053,11 +1053,13 @@ class Dataset(_Dataset):
|
|||
task,
|
||||
]
|
||||
|
||||
# Base TTS (<resp> => <text>)
|
||||
# Base STT (<resp> => <text>)
|
||||
elif task == "stt":
|
||||
# easier to just keep it instead of wrangling around trying to remove it
|
||||
# it might also help to provide a guidance prompt but who knows right now
|
||||
proms = self.sample_prompts(spkr_name, ignore=path)
|
||||
proms = [
|
||||
task
|
||||
]
|
||||
|
||||
# noise suppression (<text>? <resp+noise> => <resp>)
|
||||
# speech removal (<text>?<resp+noise> => <noise>)
|
||||
|
|
|
@ -128,6 +128,7 @@ class TTS():
|
|||
text,
|
||||
references,
|
||||
language="en",
|
||||
task="tts",
|
||||
#
|
||||
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
||||
max_nar_levels=7,
|
||||
|
@ -181,6 +182,40 @@ class TTS():
|
|||
|
||||
set_seed(seed)
|
||||
|
||||
if task == "stt":
|
||||
resp = self.encode_audio( references )
|
||||
lang = self.encode_lang( language )
|
||||
|
||||
reps = to_device(reps, device=self.device, dtype=torch.int16)
|
||||
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||
|
||||
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||
if model_ar is not None:
|
||||
text_list = model_ar(
|
||||
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
|
||||
sampling_temperature=ar_temp,
|
||||
sampling_min_temperature=min_ar_temp,
|
||||
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||
sampling_length_penalty=length_penalty,
|
||||
sampling_beam_width=beam_width,
|
||||
sampling_mirostat_tau=mirostat_tau,
|
||||
sampling_mirostat_eta=mirostat_eta,
|
||||
sampling_dry_multiplier=dry_multiplier,
|
||||
sampling_dry_base=dry_base,
|
||||
sampling_dry_allowed_length=dry_allowed_length,
|
||||
|
||||
disable_tqdm=not tqdm,
|
||||
)
|
||||
else:
|
||||
raise Exception("!")
|
||||
|
||||
text_list = [ cfg.tokenizer.decode( text ) for text in text_list ]
|
||||
print( text_list )
|
||||
|
||||
return text_list[0]
|
||||
|
||||
|
||||
for line in lines:
|
||||
if out_path is None:
|
||||
output_dir = Path("./data/results/")
|
||||
|
|
|
@ -357,11 +357,14 @@ def example_usage():
|
|||
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
|
||||
from ..engines import Engine, Engines
|
||||
from ..utils import wrapper as ml
|
||||
from ..utils import setup_logging
|
||||
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
setup_logging()
|
||||
device = "cuda"
|
||||
|
||||
|
||||
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
|
||||
"""
|
||||
|
|
|
@ -120,7 +120,7 @@ class TqdmLoggingHandler(logging.Handler):
|
|||
self.handleError(record)
|
||||
|
||||
@global_leader_only
|
||||
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
|
||||
def setup_logging(log_dir: str | Path | None = None, log_level="info"):
|
||||
handlers = []
|
||||
|
||||
#stdout_handler = StreamHandler()
|
||||
|
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
|
||||
from .inference import TTS, cfg
|
||||
from .train import train
|
||||
from .utils import get_devices
|
||||
from .utils import get_devices, setup_logging
|
||||
|
||||
tts = None
|
||||
|
||||
|
@ -338,6 +338,8 @@ with ui:
|
|||
gr.Markdown(md)
|
||||
|
||||
def start( lock=True ):
|
||||
setup_logging()
|
||||
|
||||
ui.queue(max_size=8)
|
||||
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user