diff --git a/vall_e/config.py b/vall_e/config.py index 7bce83d..606f90e 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -482,6 +482,8 @@ class Trainer: weight_dtype: str = "float16" amp: bool = False + load_webui: bool = False + backend: str = "local" deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index faf0e32..60e111c 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -21,6 +21,9 @@ try: except Exception as e: pass +from functools import cache + +@cache def load_engines(): models = get_models(cfg.models.get()) engines = dict() @@ -71,13 +74,13 @@ def load_engines(): lr_scheduler = None # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present - if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists(): + load_path = cfg.ckpt_dir / name / "fp32.pth" + if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists(): print("DeepSpeed checkpoint missing, but weights found.") loads_state_dict = True stats = None if loads_state_dict: - load_path = cfg.ckpt_dir / name / "fp32.pth" state = torch.load(load_path, map_location=torch.device(cfg.device)) # state dict is not just the module, extract the extra trainer details diff --git a/vall_e/train.py b/vall_e/train.py index e7821a5..dda58da 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -147,7 +147,7 @@ def run_eval(engines, eval_name, dl): _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") -def main(): +def train(): setup_logging(cfg.log_dir) train_dl, subtrain_dl, val_dl = create_train_val_dataloader() @@ -165,6 +165,12 @@ def main(): qnt.unload_model() + """ + if cfg.trainer.load_webui: + from .webui import start + start(lock=False) + """ + trainer.train( train_dl=train_dl, train_feeder=train_feeder, @@ -172,4 +178,4 @@ def main(): ) if __name__ == "__main__": - main() + train() diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index d9bf9e0..d1afd84 100755 --- a/vall_e/utils/trainer.py +++ b/vall_e/utils/trainer.py @@ -173,7 +173,8 @@ def train( elapsed_time = stats.get("elapsed_time", 0) - _logger.info(f"Training Metrics: {json.dumps(stats)}.") + metrics = json.dumps(stats) + _logger.info(f"Training Metrics: {metrics}.") command = _non_blocking_input() diff --git a/vall_e/webui.py b/vall_e/webui.py index 481b11e..5b55cbf 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -12,16 +12,18 @@ from time import perf_counter from pathlib import Path from .inference import TTS +from .train import train tts = None layout = {} layout["inference"] = {} -layout["inference"]["inputs"] = { - "progress": None -} -layout["inference"]["outputs"] = {} -layout["inference"]["buttons"] = {} +layout["training"] = {} + +for k in layout.keys(): + layout[k]["inputs"] = { "progress": None } + layout[k]["outputs"] = {} + layout[k]["buttons"] = {} # there's got to be a better way to go about this def gradio_wrapper(inputs): @@ -123,6 +125,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): wav = wav.squeeze(0).cpu().numpy() return (sr, wav) +""" +@gradio_wrapper(inputs=layout["training"]["inputs"].keys()) +def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ): + while True: + metrics = next(it) + yield metrics +""" + def get_random_prompt(): harvard_sentences=[ "The birch canoe slid on the smooth planks.", @@ -225,6 +235,22 @@ with ui: inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None], outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None] ) + + """ + with gr.Tab("Training"): + with gr.Row(): + with gr.Column(scale=1): + layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log") + with gr.Row(): + with gr.Column(scale=1): + layout["training"]["buttons"]["train"] = gr.Button(value="Train") + + layout["training"]["buttons"]["train"].click( + fn=do_training, + outputs=[ x for x in layout["training"]["outputs"].values() if x is not None], + ) + """ + if os.path.exists("README.md") and args.render_markdown: md = open("README.md", "r", encoding="utf-8").read() # remove HF's metadata @@ -232,5 +258,9 @@ with ui: md = "".join(md.split("---")[2:]) gr.Markdown(md) -ui.queue(max_size=8) -ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port) \ No newline at end of file +def start( lock=True ): + ui.queue(max_size=8) + ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock) + +if __name__ == "__main__": + start() \ No newline at end of file