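"""Gradio web UI for the image classifier: an Inference tab for classifying an
uploaded image, and a Settings tab for loading a model (YAML config, device, dtype)."""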
import os
import re
import argparse
import random
import tempfile
import functools

from datetime import datetime
from time import perf_counter
from pathlib import Path

import gradio as gr
from PIL import Image

from .inference import Classifier, cfg
from .train import train
from .utils import get_devices

classifier = None

layout = {}
layout["inference"] = {}
layout["training"] = {}
layout["settings"] = {}

for k in layout.keys():
    layout[k]["inputs"] = { "progress": None }
    layout[k]["outputs"] = {}
    layout[k]["buttons"] = {}

# maps the positional values Gradio passes in onto keyword arguments, following
# the order of the given input keys (there's got to be a better way to go about this)
def gradio_wrapper(inputs):
    def decorated(fun):
        @functools.wraps(fun)
        def wrapped_function(*args, **kwargs):
            for i, key in enumerate(inputs):
                kwargs[key] = args[i]
            try:
                return fun(**kwargs)
            except Exception as e:
                raise gr.Error(str(e))
        return wrapped_function
    return decorated

class timer:
    def __init__(self, msg="Elapsed time:"):
        self.msg = msg

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'

        gr.Info(msg)
        print(f'[{datetime.now().isoformat()}] {msg}')

# returns a list of model configs, assuming they are placed under ./data/, ./training/, or ./models/
def get_model_paths( paths=[Path("./data/"), Path("./training/"), Path("./models/")] ):
    yamls = []

    for path in paths:
        if not path.exists():
            continue

        for yaml in path.glob("**/*.yaml"):
            if "/logs/" in str(yaml):
                continue

            yamls.append( yaml )

    return yamls

def get_dtypes():
    return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]

#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model( yaml, device, dtype ):
    gr.Info(f"Loading: {yaml}")
    try:
        init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype )
    except Exception as e:
        raise gr.Error(str(e))
    gr.Info("Loaded model")

def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
    global classifier

    if classifier is not None:
        if not restart:
            return classifier

        del classifier
        classifier = None

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml)) # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--device", type=str, default=device)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=dtype)
    args, unknown = parser.parse_known_args()

    classifier = Classifier( config=args.yaml if yaml is None else yaml, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
    return classifier

@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    if not cfg.yaml_path:
        raise Exception("No YAML loaded.")

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # I'm very sure I can procedurally generate this list
    parser.add_argument("--image", type=str, default=kwargs["image"])
    parser.add_argument("--temp", type=float, default=kwargs["temp"])
    args, unknown = parser.parse_known_args()

    classifier = init_classifier()

    args.image = Image.open(args.image).convert('RGB')

    gr.Info("Inferencing...")
    with timer("Inferenced in") as t:
        answer = classifier.inference(
            image=args.image,
            temperature=args.temp,
        )

    return answer

# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Address for Gradio to listen on, in the form [host:port][/path]")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
args, unknown = parser.parse_known_args()

args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
    try:
        match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]

        args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
        args.listen_port = match[1] if match[1] != "" else None
        args.listen_path = match[2] if match[2] != "" else "/"
    except Exception as e:
        pass

if args.listen_port is not None:
    args.listen_port = int(args.listen_port)
    if args.listen_port == 0:
        args.listen_port = None

# setup gradio
ui = gr.Blocks()
with ui:
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=4):
                layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
                layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
            with gr.Column(scale=4):
                with gr.Row():
                    layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Sampling temperature; higher values add randomness, 0 samples greedily.")
                layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")

        layout["inference"]["buttons"]["inference"].click(
            fn=do_inference,
            inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
            outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
        )

"""
|
|
with gr.Tab("Training"):
|
|
with gr.Row():
|
|
with gr.Column(scale=1):
|
|
layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
|
|
with gr.Row():
|
|
with gr.Column(scale=1):
|
|
layout["training"]["buttons"]["train"] = gr.Button(value="Train")
|
|
|
|
layout["training"]["buttons"]["train"].click(
|
|
fn=do_training,
|
|
outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
|
|
)
|
|
"""
|
|
|
|
with gr.Tab("Settings"):
|
|
with gr.Row():
|
|
with gr.Column(scale=7):
|
|
with gr.Row():
|
|
layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
|
|
layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
|
|
layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
|
|
with gr.Column(scale=1):
|
|
layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
|
|
|
|
layout["settings"]["buttons"]["load"].click(
|
|
fn=load_model,
|
|
inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
|
|
outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
|
|
)
|
|
|
|
    if os.path.exists("README.md") and args.render_markdown:
        with open("README.md", "r", encoding="utf-8") as f:
            md = f.read()
        # remove HF's metadata
        if md.startswith("---\n"):
            md = "".join(md.split("---")[2:])
        gr.Markdown(md)

def start( lock=True ):
    ui.queue(max_size=8)
    ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)

if __name__ == "__main__":
    start()
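
# Because of the relative imports above, this file must be run as a module within its package.
# Example invocation (the "image_classifier" package name below is illustrative only; substitute the real one):
#   python3 -m image_classifier.webui --listen 0.0.0.0:7860 --yaml ./models/config.yaml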