fixed issue with training from scratch (oops)

mrq 2023-10-21 09:55:38 -05:00
parent 3195026dba
commit 32d4271ca8
5 changed files with 54 additions and 12 deletions

View File

@@ -482,6 +482,8 @@ class Trainer:
     weight_dtype: str = "float16"
     amp: bool = False
     load_webui: bool = False
+    backend: str = "local"
+    deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)

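Note: the two added Trainer fields select the training backend and attach a DeepSpeed sub-config. Below is a minimal sketch of the dataclass `default_factory` pattern these fields rely on, with a hypothetical stand-in for the real DeepSpeed sub-config; note the diff's lambda returns the class object itself, while the sketch instantiates it, which is the more conventional form:

from dataclasses import dataclass, field

@dataclass
class DeepSpeed:
    # hypothetical field for illustration; the real sub-config lives elsewhere
    use_compression_training: bool = False

@dataclass
class Trainer:
    backend: str = "local"
    # mutable defaults must go through default_factory so each Trainer
    # instance gets its own DeepSpeed object rather than a shared one
    deepspeed: DeepSpeed = field(default_factory=DeepSpeed)

cfg = Trainer(backend="deepspeed")
print(cfg.deepspeed.use_compression_training)  # False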
View File

@@ -21,6 +21,9 @@ try:
 except Exception as e:
     pass

+from functools import cache
+
+@cache
 def load_engines():
     models = get_models(cfg.models.get())
     engines = dict()
@@ -71,13 +74,13 @@ def load_engines():
     lr_scheduler = None

     # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-    if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
+    load_path = cfg.ckpt_dir / name / "fp32.pth"
+    if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
         print("DeepSpeed checkpoint missing, but weights found.")
         loads_state_dict = True

     stats = None
     if loads_state_dict:
-        load_path = cfg.ckpt_dir / name / "fp32.pth"
         state = torch.load(load_path, map_location=torch.device(cfg.device))

         # state dict is not just the module, extract the extra trainer details

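Note: this hunk is the actual from-scratch fix. Before, `load_path` was only defined inside `if loads_state_dict:`, and the fallback branch flipped `loads_state_dict = True` whenever the DeepSpeed checkpoint was missing, even on a brand-new run with no `fp32.pth` to load. Hoisting `load_path` and adding `load_path.exists()` to the guard lets a fresh model initialize normally. The new `@cache` on `load_engines` additionally memoizes the call, so repeat invocations reuse the same engines. A reduced sketch of the fixed control flow, with simplified names (engine construction and error handling omitted):

from pathlib import Path
import torch

def resolve_weights(ckpt_dir: Path, name: str, backend: str, loads_state_dict: bool):
    load_path = ckpt_dir / name / "fp32.pth"
    # fall back to raw weights only when the DeepSpeed checkpoint is absent
    # AND the exported weights actually exist; a from-scratch run has neither
    # and must be left alone to initialize randomly
    if not loads_state_dict and backend == "deepspeed" \
            and not (ckpt_dir / name / "latest").exists() and load_path.exists():
        print("DeepSpeed checkpoint missing, but weights found.")
        loads_state_dict = True

    state = None
    if loads_state_dict:
        state = torch.load(load_path, map_location="cpu")
    return loads_state_dict, state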
View File

@@ -147,7 +147,7 @@ def run_eval(engines, eval_name, dl):
     _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")

-def main():
+def train():
     setup_logging(cfg.log_dir)

     train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@@ -165,6 +165,12 @@ def main():
     qnt.unload_model()

+    """
+    if cfg.trainer.load_webui:
+        from .webui import start
+        start(lock=False)
+    """
+
     trainer.train(
         train_dl=train_dl,
         train_feeder=train_feeder,
@@ -172,4 +178,4 @@ def main():
     )

 if __name__ == "__main__":
-    main()
+    train()

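Note: `main()` becomes `train()` so the entry point has an importable name; the web UI's (for now commented-out) training hook does `from .train import train`, and the commented `load_webui` block here would be the inverse direction, training spinning up the UI without blocking. A minimal sketch of the entry-point shape, body stubbed for illustration:

def train():
    # setup_logging(cfg.log_dir), dataloader creation, and trainer.train(...)
    # go here; stubbed out in this sketch
    print("training loop runs here")

if __name__ == "__main__":
    train()

Keeping the loop in a named function means the UI can eventually call it in-process instead of shelling out to a subprocess.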
View File

@@ -173,7 +173,8 @@ def train(
         elapsed_time = stats.get("elapsed_time", 0)

-        _logger.info(f"Training Metrics: {json.dumps(stats)}.")
+        metrics = json.dumps(stats)
+        _logger.info(f"Training Metrics: {metrics}.")

         command = _non_blocking_input()

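Note: hoisting `json.dumps(stats)` into a local looks cosmetic, but it leaves the serialized metrics string available to a second consumer, which matches the stubbed `do_training` in the web UI that expects metrics to yield into a console box. A sketch of that shape, with hypothetical names:

import json
import logging

_logger = logging.getLogger(__name__)

def emit_metrics(stats: dict) -> str:
    metrics = json.dumps(stats)
    _logger.info(f"Training Metrics: {metrics}.")
    return metrics  # the same string can later be yielded to the web UI

emit_metrics({"global_step": 100, "loss": 2.31, "elapsed_time": 12.8})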
View File

@@ -12,16 +12,18 @@ from time import perf_counter
 from pathlib import Path

 from .inference import TTS
+from .train import train

 tts = None

 layout = {}
 layout["inference"] = {}
-layout["inference"]["inputs"] = {
-    "progress": None
-}
-layout["inference"]["outputs"] = {}
-layout["inference"]["buttons"] = {}
+layout["training"] = {}
+
+for k in layout.keys():
+    layout[k]["inputs"] = { "progress": None }
+    layout[k]["outputs"] = {}
+    layout[k]["buttons"] = {}

 # there's got to be a better way to go about this
 def gradio_wrapper(inputs):
@@ -123,6 +125,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
     wav = wav.squeeze(0).cpu().numpy()
     return (sr, wav)

+"""
+@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
+def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
+    while True:
+        metrics = next(it)
+        yield metrics
+"""
+
 def get_random_prompt():
     harvard_sentences=[
         "The birch canoe slid on the smooth planks.",
@@ -225,6 +235,22 @@ with ui:
             inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
             outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
         )

+        """
+        with gr.Tab("Training"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    layout["training"]["buttons"]["train"] = gr.Button(value="Train")
+
+        layout["training"]["buttons"]["train"].click(
+            fn=do_training,
+            outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
+        )
+        """
+
     if os.path.exists("README.md") and args.render_markdown:
         md = open("README.md", "r", encoding="utf-8").read()
         # remove HF's metadata
@ -232,5 +258,9 @@ with ui:
md = "".join(md.split("---")[2:])
gr.Markdown(md)
def start( lock=True ):
ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
if __name__ == "__main__":
start()
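Note: wrapping queue/launch in `start(lock=True)` makes the UI dual-use: run standalone it blocks as before, but called with `lock=False` (as the commented-out hook in train would do), Gradio's `prevent_thread_lock=True` makes `launch()` return immediately so the caller keeps running. A minimal self-contained sketch with placeholder Blocks content:

import gradio as gr

with gr.Blocks() as ui:
    gr.Markdown("placeholder UI")

def start(lock: bool = True):
    ui.queue(max_size=8)
    # prevent_thread_lock=True makes launch() return instead of blocking,
    # so a training process can spin the UI up and continue its loop
    ui.launch(prevent_thread_lock=not lock)

if __name__ == "__main__":
    start()  # standalone: block so the server stays alive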