tweak
This commit is contained in:
parent
413097f5f7
commit
94cf81d38c
|
@ -13,6 +13,7 @@ def main():
|
||||||
parser.add_argument("text")
|
parser.add_argument("text")
|
||||||
parser.add_argument("references", type=path_list, default=None)
|
parser.add_argument("references", type=path_list, default=None)
|
||||||
parser.add_argument("--language", type=str, default="en")
|
parser.add_argument("--language", type=str, default="en")
|
||||||
|
parser.add_argument("--task", type=str, default="tts")
|
||||||
parser.add_argument("--out-path", type=Path, default=None)
|
parser.add_argument("--out-path", type=Path, default=None)
|
||||||
|
|
||||||
parser.add_argument("--yaml", type=Path, default=None)
|
parser.add_argument("--yaml", type=Path, default=None)
|
||||||
|
@ -53,6 +54,7 @@ def main():
|
||||||
text=args.text,
|
text=args.text,
|
||||||
references=args.references,
|
references=args.references,
|
||||||
language=args.language,
|
language=args.language,
|
||||||
|
task=args.task,
|
||||||
out_path=args.out_path,
|
out_path=args.out_path,
|
||||||
input_prompt_length=args.input_prompt_length,
|
input_prompt_length=args.input_prompt_length,
|
||||||
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
max_ar_steps=args.max_ar_steps, max_nar_levels=args.max_nar_levels,
|
||||||
|
|
|
@ -1053,11 +1053,13 @@ class Dataset(_Dataset):
|
||||||
task,
|
task,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Base TTS (<resp> => <text>)
|
# Base STT (<resp> => <text>)
|
||||||
elif task == "stt":
|
elif task == "stt":
|
||||||
# easier to just keep it instead of wrangling around trying to remove it
|
# easier to just keep it instead of wrangling around trying to remove it
|
||||||
# it might also help to provide a guidance prompt but who knows right now
|
# it might also help to provide a guidance prompt but who knows right now
|
||||||
proms = self.sample_prompts(spkr_name, ignore=path)
|
proms = [
|
||||||
|
task
|
||||||
|
]
|
||||||
|
|
||||||
# noise suppression (<text>? <resp+noise> => <resp>)
|
# noise suppression (<text>? <resp+noise> => <resp>)
|
||||||
# speech removal (<text>?<resp+noise> => <noise>)
|
# speech removal (<text>?<resp+noise> => <noise>)
|
||||||
|
|
|
@ -128,6 +128,7 @@ class TTS():
|
||||||
text,
|
text,
|
||||||
references,
|
references,
|
||||||
language="en",
|
language="en",
|
||||||
|
task="tts",
|
||||||
#
|
#
|
||||||
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
max_ar_steps=6 * cfg.dataset.frames_per_second,
|
||||||
max_nar_levels=7,
|
max_nar_levels=7,
|
||||||
|
@ -181,6 +182,40 @@ class TTS():
|
||||||
|
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
|
|
||||||
|
if task == "stt":
|
||||||
|
resp = self.encode_audio( references )
|
||||||
|
lang = self.encode_lang( language )
|
||||||
|
|
||||||
|
reps = to_device(reps, device=self.device, dtype=torch.int16)
|
||||||
|
lang = to_device(lang, device=self.device, dtype=torch.uint8)
|
||||||
|
|
||||||
|
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
||||||
|
if model_ar is not None:
|
||||||
|
text_list = model_ar(
|
||||||
|
text_list=None, proms_list=[resp], lang_list=[lang], resps_list=[resp], max_steps=max_ar_steps,
|
||||||
|
sampling_temperature=ar_temp,
|
||||||
|
sampling_min_temperature=min_ar_temp,
|
||||||
|
sampling_top_p=top_p, sampling_top_k=top_k,
|
||||||
|
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
|
||||||
|
sampling_length_penalty=length_penalty,
|
||||||
|
sampling_beam_width=beam_width,
|
||||||
|
sampling_mirostat_tau=mirostat_tau,
|
||||||
|
sampling_mirostat_eta=mirostat_eta,
|
||||||
|
sampling_dry_multiplier=dry_multiplier,
|
||||||
|
sampling_dry_base=dry_base,
|
||||||
|
sampling_dry_allowed_length=dry_allowed_length,
|
||||||
|
|
||||||
|
disable_tqdm=not tqdm,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("!")
|
||||||
|
|
||||||
|
text_list = [ cfg.tokenizer.decode( text ) for text in text_list ]
|
||||||
|
print( text_list )
|
||||||
|
|
||||||
|
return text_list[0]
|
||||||
|
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
output_dir = Path("./data/results/")
|
output_dir = Path("./data/results/")
|
||||||
|
|
|
@ -357,12 +357,15 @@ def example_usage():
|
||||||
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
|
from ..emb.qnt import decode_to_file, unload_model, trim_random, repeat_extend_audio, concat_audio, merge_audio
|
||||||
from ..engines import Engine, Engines
|
from ..engines import Engine, Engines
|
||||||
from ..utils import wrapper as ml
|
from ..utils import wrapper as ml
|
||||||
|
from ..utils import setup_logging
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
|
|
||||||
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
|
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
|
||||||
"""
|
"""
|
||||||
if "mamba" in cfg.model.arch_type:
|
if "mamba" in cfg.model.arch_type:
|
||||||
|
|
|
@ -120,7 +120,7 @@ class TqdmLoggingHandler(logging.Handler):
|
||||||
self.handleError(record)
|
self.handleError(record)
|
||||||
|
|
||||||
@global_leader_only
|
@global_leader_only
|
||||||
def setup_logging(log_dir: str | Path | None = "log", log_level="info"):
|
def setup_logging(log_dir: str | Path | None = None, log_level="info"):
|
||||||
handlers = []
|
handlers = []
|
||||||
|
|
||||||
#stdout_handler = StreamHandler()
|
#stdout_handler = StreamHandler()
|
||||||
|
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
||||||
|
|
||||||
from .inference import TTS, cfg
|
from .inference import TTS, cfg
|
||||||
from .train import train
|
from .train import train
|
||||||
from .utils import get_devices
|
from .utils import get_devices, setup_logging
|
||||||
|
|
||||||
tts = None
|
tts = None
|
||||||
|
|
||||||
|
@ -338,6 +338,8 @@ with ui:
|
||||||
gr.Markdown(md)
|
gr.Markdown(md)
|
||||||
|
|
||||||
def start( lock=True ):
|
def start( lock=True ):
|
||||||
|
setup_logging()
|
||||||
|
|
||||||
ui.queue(max_size=8)
|
ui.queue(max_size=8)
|
||||||
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
|
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user