resnet-classifier/image_classifier/webui.py

220 lines
6.8 KiB
Python

import os
import re
import argparse
import random
import tempfile
import functools
from datetime import datetime
import gradio as gr
from time import perf_counter
from pathlib import Path
from PIL import Image
from .inference import Classifier, cfg
from .train import train
from .utils import get_devices
# Shared Classifier instance; created / replaced by init_classifier().
classifier = None

# Registry of Gradio components per tab. Each tab gets its own "inputs",
# "outputs" and "buttons" dicts, filled in while the UI is built.
# The "progress" entry is a None placeholder that is filtered out of the
# click() input lists; it is kept first, presumably so gradio_wrapper's
# positional key mapping lines up with the Progress argument Gradio supplies.
layout = {
    tab_name: {"inputs": {"progress": None}, "outputs": {}, "buttons": {}}
    for tab_name in ("inference", "training", "settings")
}
# there's got to be a better way to go about this
def gradio_wrapper(inputs):
def decorated(fun):
@functools.wraps(fun)
def wrapped_function(*args, **kwargs):
for i, key in enumerate(inputs):
kwargs[key] = args[i]
try:
return fun(**kwargs)
except Exception as e:
raise gr.Error(str(e))
return wrapped_function
return decorated
class timer:
    """Context manager that reports how long its body took.

    On exit, the time elapsed since ``__enter__`` (measured with
    ``perf_counter``) is shown with ``gr.Info`` and printed to stdout with an
    ISO-8601 timestamp. Exceptions raised inside the body are not suppressed.

    Args:
        msg: prefix for the report, e.g. ``"Inferenced in"``.
    """
    def __init__(self, msg="Elapsed time:"):
        self.msg = msg
        self.start = None
        self.elapsed = None  # seconds; set when the block exits

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Parameter names avoid shadowing the builtin ``type``.
        self.elapsed = perf_counter() - self.start
        msg = f'{self.msg} {self.elapsed:.3f}s'
        gr.Info(msg)
        print(f'[{datetime.now().isoformat()}] {msg}')
        return False  # never swallow exceptions from the body
def get_model_paths(paths=None):
    """Return every YAML model config found under the given search roots.

    Args:
        paths: directories to search recursively. Defaults to ``./data/``,
            ``./training/`` and ``./models/``. Roots that do not exist are
            skipped, and ``str`` roots are accepted as well as ``Path``.

    Returns:
        list[Path]: ``*.yaml`` files in glob order, excluding any file that
        sits inside a directory named ``logs``.
    """
    # None sentinel instead of a list default (mutable default argument).
    if paths is None:
        paths = [Path("./data/"), Path("./training/"), Path("./models/")]
    yamls = []
    for root in paths:
        root = Path(root)
        if not root.exists():
            continue
        for yaml in root.glob("**/*.yaml"):
            # Compare path components rather than searching for a "/logs/"
            # substring, which never matches with Windows "\" separators.
            if "logs" in yaml.parts[:-1]:
                continue
            yamls.append(yaml)
    return yamls
def get_dtypes():
    """List the precision choices offered in the Settings tab.

    A new list is returned on every call. ``"auto"`` is forwarded to the
    classifier as ``None`` by ``init_classifier``.
    """
    standard = ["float32", "float16", "bfloat16"]
    fp8 = ["float8_e5m2", "float8_e4m3fn"]
    return standard + fp8 + ["auto"]
# NOTE: not wrapped with gradio_wrapper; the Settings "Load Model" button
# passes the three dropdown values positionally (model, device, dtype).
def load_model( yaml, device, dtype ):
    """(Re)load the shared classifier from the Settings tab.

    Args:
        yaml: path of the model's YAML config (the "Model" dropdown value).
        device: device string from the "Device" dropdown.
        dtype: precision from the "Precision" dropdown; ``"auto"`` is passed
            on to the classifier as ``None`` by ``init_classifier``.

    Raises:
        gr.Error: if no model is selected or loading fails, so Gradio shows
            the message in the UI.
    """
    if not yaml:
        # Path(None) would otherwise fail with an opaque TypeError.
        raise gr.Error("No model selected.")
    gr.Info(f"Loading: {yaml}")
    try:
        init_classifier( yaml=Path(yaml), restart=True, device=device, dtype=dtype )
    except Exception as e:
        # gr.Error takes a string message (same convention as gradio_wrapper).
        raise gr.Error(str(e)) from e
    gr.Info("Loaded model")
def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
    """Return the module-wide Classifier, building it on first use.

    Args:
        yaml: config path. When given, it takes precedence over ``--yaml``
            and ``$CLASSIFIER_YAML``, which are only consulted when this is None.
        restart: if True, drop any existing instance and build a new one.
        device: default for ``--device``; a ``--device`` flag on the
            command line wins over it.
        dtype: default for ``--dtype`` (a command-line flag wins over it).
            ``"auto"`` is passed to Classifier as ``None``.

    ``--amp`` on the command line is forwarded as ``amp=True``. Unknown
    command-line flags are ignored (``parse_known_args``).
    """
    global classifier

    if classifier is not None and not restart:
        return classifier
    if classifier is not None:
        # Release the old instance before constructing the replacement.
        del classifier
        classifier = None

    cli = argparse.ArgumentParser(allow_abbrev=False)
    cli.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml))  # os environ so it can be specified in a HuggingFace Space too
    cli.add_argument("--device", type=str, default=device)
    cli.add_argument("--amp", action="store_true")
    cli.add_argument("--dtype", type=str, default=dtype)
    opts, _unknown = cli.parse_known_args()

    config = yaml if yaml is not None else opts.yaml
    resolved_dtype = None if opts.dtype == "auto" else opts.dtype
    classifier = Classifier(config=config, device=opts.device, dtype=resolved_dtype, amp=opts.amp)
    return classifier
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
    """Classify the uploaded image (handler for the Inference tab's button).

    ``gradio_wrapper`` maps Gradio's positional inputs onto the keys of
    ``layout["inference"]["inputs"]`` (progress, image, temp). It also turns
    any exception raised here into a ``gr.Error``.

    Keyword args:
        image: filepath of the uploaded image (the gr.Image component uses
            ``type="filepath"``); presumably None when nothing was uploaded.
        temp: sampling temperature from the slider.
        ``--image`` / ``--temp`` on the command line override these values.

    Returns:
        Whatever ``classifier.inference`` returns; it is shown in the Output
        textbox.

    Raises:
        Exception: if no YAML config is loaded or no image was provided.
    """
    if not cfg.yaml_path:
        raise Exception("No YAML loaded.")

    # Named `cli`/`opts` so they don't shadow the *args parameter.
    cli = argparse.ArgumentParser(allow_abbrev=False)
    cli.add_argument("--image", type=str, default=kwargs["image"])
    cli.add_argument("--temp", type=float, default=kwargs["temp"])
    opts, _unknown = cli.parse_known_args()

    if not opts.image:
        # Image.open(None) would otherwise fail with an obscure AttributeError.
        raise Exception("No image provided.")

    classifier = init_classifier()

    # Context manager closes the file handle deterministically; convert()
    # returns a fully loaded copy, so it stays valid after the file is closed.
    with Image.open(opts.image) as source:
        image = source.convert('RGB')

    gr.Info("Inferencing...")
    with timer("Inferenced in"):
        answer = classifier.inference(
            image=image,
            temperature=opts.temp,
        )
    return answer
# setup args
def _parse_listen(listen):
    """Parse a ``--listen`` value of the form ``[host:port][/path]``.

    Returns:
        ``(host, port, path)``, or None if ``listen`` does not match that form.
        ``host`` defaults to ``"127.0.0.1"``. ``port`` is an int, or None when
        absent or 0. ``path`` defaults to ``"/"``.
    """
    match = re.match(r"^(?:(.+?):(\d+))?(\/.*?)?$", listen)
    if match is None:
        return None
    host, port, path = match.groups()
    port = int(port) if port else None
    # A port of 0 is treated as "unset" so Gradio picks its default.
    return (host or "127.0.0.1", port or None, path or "/")


parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
args, unknown = parser.parse_known_args()

args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
    parsed = _parse_listen(args.listen)
    if parsed is None:
        # Best effort: keep Gradio's defaults, but don't ignore the bad value silently.
        print(f"Ignoring unrecognized --listen value {args.listen!r} (expected [host:port][/path])")
    else:
        args.listen_host, args.listen_port, args.listen_path = parsed
# setup gradio
ui = gr.Blocks()
with ui:
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=4):
                # Creation order matters: gradio_wrapper maps do_inference's
                # positional inputs onto these keys in insertion order
                # (progress, image, temp).
                layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
                layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
            with gr.Column(scale=4):
                with gr.Row():
                    layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the randomness from the samples. (0 to greedy sample)")
                layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")
        layout["inference"]["buttons"]["inference"].click(
            fn=do_inference,
            # the None "progress" placeholder is filtered out here
            inputs=[ x for x in layout["inference"]["inputs"].values() if x is not None],
            outputs=[ x for x in layout["inference"]["outputs"].values() if x is not None]
        )

    # NOTE: the Training tab is disabled (kept as a string literal); it refers
    # to a do_training handler that is not defined here.
    """
    with gr.Tab("Training"):
        with gr.Row():
            with gr.Column(scale=1):
                layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
        with gr.Row():
            with gr.Column(scale=1):
                layout["training"]["buttons"]["train"] = gr.Button(value="Train")
        layout["training"]["buttons"]["train"].click(
            fn=do_training,
            outputs=[ x for x in layout["training"]["outputs"].values() if x is not None],
        )
    """

    with gr.Tab("Settings"):
        with gr.Row():
            with gr.Column(scale=7):
                with gr.Row():
                    # Order here is the positional order load_model receives
                    # (yaml, device, dtype).
                    layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
                    layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
                    layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
            with gr.Column(scale=1):
                layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")
        layout["settings"]["buttons"]["load"].click(
            fn=load_model,
            inputs=[ x for x in layout["settings"]["inputs"].values() if x is not None],
            outputs=[ x for x in layout["settings"]["outputs"].values() if x is not None],
        )

    if os.path.exists("README.md") and args.render_markdown:
        # `with` closes the file (the old open(...).read() leaked the handle).
        with open("README.md", "r", encoding="utf-8") as readme:
            md = readme.read()
        # remove HF's metadata: the front matter between the first two "---"
        # markers. Split at most twice so later "---" horizontal rules in the
        # README body are preserved instead of being deleted.
        if md.startswith("---\n"):
            parts = md.split("---", 2)
            md = parts[2] if len(parts) > 2 else ""
        gr.Markdown(md)
def start(lock=True):
    """Enable request queueing (at most 8 pending events) and launch the UI.

    Args:
        lock: forwarded, negated, as Gradio's ``prevent_thread_lock``. The
            default True keeps the calling thread attached to the server;
            False returns immediately.

    Host, port and ``--share`` come from the module-level command-line ``args``.
    """
    ui.queue(max_size=8)
    launch_options = {
        "share": args.share,
        "server_name": args.listen_host,
        "server_port": args.listen_port,
        "prevent_thread_lock": not lock,
    }
    ui.launch(**launch_options)
if __name__ == "__main__":
    # Direct execution launches the UI with the default lock=True.
    start(lock=True)