fixed issue with training from scratch (oops)
This commit is contained in:
parent
3195026dba
commit
32d4271ca8
|
@ -482,6 +482,8 @@ class Trainer:
|
|||
weight_dtype: str = "float16"
|
||||
amp: bool = False
|
||||
|
||||
load_webui: bool = False
|
||||
|
||||
backend: str = "local"
|
||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||
|
||||
|
|
|
@ -21,6 +21,9 @@ try:
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
from functools import cache
|
||||
|
||||
@cache
|
||||
def load_engines():
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
|
@ -71,13 +74,13 @@ def load_engines():
|
|||
lr_scheduler = None
|
||||
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
|
||||
print("DeepSpeed checkpoint missing, but weights found.")
|
||||
loads_state_dict = True
|
||||
|
||||
stats = None
|
||||
if loads_state_dict:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
|
|
|
@ -147,7 +147,7 @@ def run_eval(engines, eval_name, dl):
|
|||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||
|
||||
|
||||
def main():
|
||||
def train():
|
||||
setup_logging(cfg.log_dir)
|
||||
|
||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||
|
@ -165,6 +165,12 @@ def main():
|
|||
|
||||
qnt.unload_model()
|
||||
|
||||
"""
|
||||
if cfg.trainer.load_webui:
|
||||
from .webui import start
|
||||
start(lock=False)
|
||||
"""
|
||||
|
||||
trainer.train(
|
||||
train_dl=train_dl,
|
||||
train_feeder=train_feeder,
|
||||
|
@ -172,4 +178,4 @@ def main():
|
|||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
train()
|
||||
|
|
|
@ -173,7 +173,8 @@ def train(
|
|||
|
||||
|
||||
elapsed_time = stats.get("elapsed_time", 0)
|
||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||
metrics = json.dumps(stats)
|
||||
_logger.info(f"Training Metrics: {metrics}.")
|
||||
|
||||
command = _non_blocking_input()
|
||||
|
||||
|
|
|
@ -12,16 +12,18 @@ from time import perf_counter
|
|||
from pathlib import Path
|
||||
|
||||
from .inference import TTS
|
||||
from .train import train
|
||||
|
||||
tts = None
|
||||
|
||||
layout = {}
|
||||
layout["inference"] = {}
|
||||
layout["inference"]["inputs"] = {
|
||||
"progress": None
|
||||
}
|
||||
layout["inference"]["outputs"] = {}
|
||||
layout["inference"]["buttons"] = {}
|
||||
layout["training"] = {}
|
||||
|
||||
for k in layout.keys():
|
||||
layout[k]["inputs"] = { "progress": None }
|
||||
layout[k]["outputs"] = {}
|
||||
layout[k]["buttons"] = {}
|
||||
|
||||
# there's got to be a better way to go about this
|
||||
def gradio_wrapper(inputs):
|
||||
|
@ -123,6 +125,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
wav = wav.squeeze(0).cpu().numpy()
|
||||
return (sr, wav)
|
||||
|
||||
"""
|
||||
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
||||
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||
while True:
|
||||
metrics = next(it)
|
||||
yield metrics
|
||||
"""
|
||||
|
||||
def get_random_prompt():
|
||||
harvard_sentences=[
|
||||
"The birch canoe slid on the smooth planks.",
|
||||
|
@ -225,6 +235,22 @@ with ui:
|
|||
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
|
||||
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
|
||||
)
|
||||
|
||||
"""
|
||||
with gr.Tab("Training"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1):
|
||||
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
|
||||
|
||||
layout["training"]["buttons"]["train"].click(
|
||||
fn=do_training,
|
||||
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
|
||||
)
|
||||
"""
|
||||
|
||||
if os.path.exists("README.md") and args.render_markdown:
|
||||
md = open("README.md", "r", encoding="utf-8").read()
|
||||
# remove HF's metadata
|
||||
|
@ -232,5 +258,9 @@ with ui:
|
|||
md = "".join(md.split("---")[2:])
|
||||
gr.Markdown(md)
|
||||
|
||||
def start( lock=True ):
|
||||
ui.queue(max_size=8)
|
||||
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port)
|
||||
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
|
||||
|
||||
if __name__ == "__main__":
|
||||
start()
|
Loading…
Reference in New Issue
Block a user