Fixed an issue with training from scratch (oops)

This commit is contained in:
mrq 2023-10-21 09:55:38 -05:00
parent 3195026dba
commit 32d4271ca8
5 changed files with 54 additions and 12 deletions

View File

@ -482,6 +482,8 @@ class Trainer:
weight_dtype: str = "float16" weight_dtype: str = "float16"
amp: bool = False amp: bool = False
load_webui: bool = False
backend: str = "local" backend: str = "local"
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed) deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)

View File

@ -21,6 +21,9 @@ try:
except Exception as e: except Exception as e:
pass pass
from functools import cache
@cache
def load_engines(): def load_engines():
models = get_models(cfg.models.get()) models = get_models(cfg.models.get())
engines = dict() engines = dict()
@ -71,13 +74,13 @@ def load_engines():
lr_scheduler = None lr_scheduler = None
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists(): load_path = cfg.ckpt_dir / name / "fp32.pth"
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
print("DeepSpeed checkpoint missing, but weights found.") print("DeepSpeed checkpoint missing, but weights found.")
loads_state_dict = True loads_state_dict = True
stats = None stats = None
if loads_state_dict: if loads_state_dict:
load_path = cfg.ckpt_dir / name / "fp32.pth"
state = torch.load(load_path, map_location=torch.device(cfg.device)) state = torch.load(load_path, map_location=torch.device(cfg.device))
# state dict is not just the module, extract the extra trainer details # state dict is not just the module, extract the extra trainer details

View File

@ -147,7 +147,7 @@ def run_eval(engines, eval_name, dl):
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def main(): def train():
setup_logging(cfg.log_dir) setup_logging(cfg.log_dir)
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@ -165,6 +165,12 @@ def main():
qnt.unload_model() qnt.unload_model()
"""
if cfg.trainer.load_webui:
from .webui import start
start(lock=False)
"""
trainer.train( trainer.train(
train_dl=train_dl, train_dl=train_dl,
train_feeder=train_feeder, train_feeder=train_feeder,
@ -172,4 +178,4 @@ def main():
) )
if __name__ == "__main__": if __name__ == "__main__":
main() train()

View File

@ -173,7 +173,8 @@ def train(
elapsed_time = stats.get("elapsed_time", 0) elapsed_time = stats.get("elapsed_time", 0)
_logger.info(f"Training Metrics: {json.dumps(stats)}.") metrics = json.dumps(stats)
_logger.info(f"Training Metrics: {metrics}.")
command = _non_blocking_input() command = _non_blocking_input()

View File

@ -12,16 +12,18 @@ from time import perf_counter
from pathlib import Path from pathlib import Path
from .inference import TTS from .inference import TTS
from .train import train
tts = None tts = None
layout = {} layout = {}
layout["inference"] = {} layout["inference"] = {}
layout["inference"]["inputs"] = { layout["training"] = {}
"progress": None
} for k in layout.keys():
layout["inference"]["outputs"] = {} layout[k]["inputs"] = { "progress": None }
layout["inference"]["buttons"] = {} layout[k]["outputs"] = {}
layout[k]["buttons"] = {}
# there's got to be a better way to go about this # there's got to be a better way to go about this
def gradio_wrapper(inputs): def gradio_wrapper(inputs):
@ -123,6 +125,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
wav = wav.squeeze(0).cpu().numpy() wav = wav.squeeze(0).cpu().numpy()
return (sr, wav) return (sr, wav)
"""
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
while True:
metrics = next(it)
yield metrics
"""
def get_random_prompt(): def get_random_prompt():
harvard_sentences=[ harvard_sentences=[
"The birch canoe slid on the smooth planks.", "The birch canoe slid on the smooth planks.",
@ -225,6 +235,22 @@ with ui:
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None], inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None] outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
) )
"""
with gr.Tab("Training"):
with gr.Row():
with gr.Column(scale=1):
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
with gr.Row():
with gr.Column(scale=1):
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
layout["training"]["buttons"]["train"].click(
fn=do_training,
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
)
"""
if os.path.exists("README.md") and args.render_markdown: if os.path.exists("README.md") and args.render_markdown:
md = open("README.md", "r", encoding="utf-8").read() md = open("README.md", "r", encoding="utf-8").read()
# remove HF's metadata # remove HF's metadata
@ -232,5 +258,9 @@ with ui:
md = "".join(md.split("---")[2:]) md = "".join(md.split("---")[2:])
gr.Markdown(md) gr.Markdown(md)
ui.queue(max_size=8) def start( lock=True ):
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port) ui.queue(max_size=8)
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
if __name__ == "__main__":
start()