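Fixed issue with training from scratch: the automatic state-dict load is now attempted only when the fp32.pth weights file actually exists (oops)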
This commit is contained in:
parent
3195026dba
commit
32d4271ca8
|
@ -482,6 +482,8 @@ class Trainer:
|
||||||
weight_dtype: str = "float16"
|
weight_dtype: str = "float16"
|
||||||
amp: bool = False
|
amp: bool = False
|
||||||
|
|
||||||
|
load_webui: bool = False
|
||||||
|
|
||||||
backend: str = "local"
|
backend: str = "local"
|
||||||
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
deepspeed: DeepSpeed = field(default_factory=lambda: DeepSpeed)
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,9 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
@cache
|
||||||
def load_engines():
|
def load_engines():
|
||||||
models = get_models(cfg.models.get())
|
models = get_models(cfg.models.get())
|
||||||
engines = dict()
|
engines = dict()
|
||||||
|
@ -71,13 +74,13 @@ def load_engines():
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||||
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
|
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||||
|
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
|
||||||
print("DeepSpeed checkpoint missing, but weights found.")
|
print("DeepSpeed checkpoint missing, but weights found.")
|
||||||
loads_state_dict = True
|
loads_state_dict = True
|
||||||
|
|
||||||
stats = None
|
stats = None
|
||||||
if loads_state_dict:
|
if loads_state_dict:
|
||||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
|
||||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||||
|
|
||||||
# state dict is not just the module, extract the extra trainer details
|
# state dict is not just the module, extract the extra trainer details
|
||||||
|
|
|
@ -147,7 +147,7 @@ def run_eval(engines, eval_name, dl):
|
||||||
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def train():
|
||||||
setup_logging(cfg.log_dir)
|
setup_logging(cfg.log_dir)
|
||||||
|
|
||||||
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
|
||||||
|
@ -165,6 +165,12 @@ def main():
|
||||||
|
|
||||||
qnt.unload_model()
|
qnt.unload_model()
|
||||||
|
|
||||||
|
"""
|
||||||
|
if cfg.trainer.load_webui:
|
||||||
|
from .webui import start
|
||||||
|
start(lock=False)
|
||||||
|
"""
|
||||||
|
|
||||||
trainer.train(
|
trainer.train(
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
train_feeder=train_feeder,
|
train_feeder=train_feeder,
|
||||||
|
@ -172,4 +178,4 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
train()
|
||||||
|
|
|
@ -173,7 +173,8 @@ def train(
|
||||||
|
|
||||||
|
|
||||||
elapsed_time = stats.get("elapsed_time", 0)
|
elapsed_time = stats.get("elapsed_time", 0)
|
||||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
metrics = json.dumps(stats)
|
||||||
|
_logger.info(f"Training Metrics: {metrics}.")
|
||||||
|
|
||||||
command = _non_blocking_input()
|
command = _non_blocking_input()
|
||||||
|
|
||||||
|
|
|
@ -12,16 +12,18 @@ from time import perf_counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .inference import TTS
|
from .inference import TTS
|
||||||
|
from .train import train
|
||||||
|
|
||||||
tts = None
|
tts = None
|
||||||
|
|
||||||
layout = {}
|
layout = {}
|
||||||
layout["inference"] = {}
|
layout["inference"] = {}
|
||||||
layout["inference"]["inputs"] = {
|
layout["training"] = {}
|
||||||
"progress": None
|
|
||||||
}
|
for k in layout.keys():
|
||||||
layout["inference"]["outputs"] = {}
|
layout[k]["inputs"] = { "progress": None }
|
||||||
layout["inference"]["buttons"] = {}
|
layout[k]["outputs"] = {}
|
||||||
|
layout[k]["buttons"] = {}
|
||||||
|
|
||||||
# there's got to be a better way to go about this
|
# there's got to be a better way to go about this
|
||||||
def gradio_wrapper(inputs):
|
def gradio_wrapper(inputs):
|
||||||
|
@ -123,6 +125,14 @@ def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
wav = wav.squeeze(0).cpu().numpy()
|
wav = wav.squeeze(0).cpu().numpy()
|
||||||
return (sr, wav)
|
return (sr, wav)
|
||||||
|
|
||||||
|
"""
|
||||||
|
@gradio_wrapper(inputs=layout["training"]["inputs"].keys())
|
||||||
|
def do_training( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
|
while True:
|
||||||
|
metrics = next(it)
|
||||||
|
yield metrics
|
||||||
|
"""
|
||||||
|
|
||||||
def get_random_prompt():
|
def get_random_prompt():
|
||||||
harvard_sentences=[
|
harvard_sentences=[
|
||||||
"The birch canoe slid on the smooth planks.",
|
"The birch canoe slid on the smooth planks.",
|
||||||
|
@ -225,6 +235,22 @@ with ui:
|
||||||
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
|
inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
|
||||||
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
|
outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
with gr.Tab("Training"):
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
|
||||||
|
|
||||||
|
layout["training"]["buttons"]["train"].click(
|
||||||
|
fn=do_training,
|
||||||
|
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
if os.path.exists("README.md") and args.render_markdown:
|
if os.path.exists("README.md") and args.render_markdown:
|
||||||
md = open("README.md", "r", encoding="utf-8").read()
|
md = open("README.md", "r", encoding="utf-8").read()
|
||||||
# remove HF's metadata
|
# remove HF's metadata
|
||||||
|
@ -232,5 +258,9 @@ with ui:
|
||||||
md = "".join(md.split("---")[2:])
|
md = "".join(md.split("---")[2:])
|
||||||
gr.Markdown(md)
|
gr.Markdown(md)
|
||||||
|
|
||||||
ui.queue(max_size=8)
|
def start( lock=True ):
|
||||||
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port)
|
ui.queue(max_size=8)
|
||||||
|
ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start()
|
Loading…
Reference in New Issue
Block a user