"""Gradio web UI for loading a Classifier and running inference on images.

The UI is built at import time; call ``start()`` (or run this module as a
script) to launch the Gradio server.
"""
import os
import re
import argparse
import random
import tempfile
import functools

from datetime import datetime

import gradio as gr

from time import perf_counter
from pathlib import Path

from PIL import Image

from .inference import Classifier, cfg
from .train import train
from .utils import get_devices

# Shared Classifier instance, created lazily by init_classifier().
classifier = None

# Registry of Gradio components per tab:
#   layout[tab]["inputs" | "outputs" | "buttons"][name] -> component
# Every tab's inputs start with a "progress" placeholder (None). It is filtered
# out of the component lists handed to Gradio, but it occupies the first key,
# which gradio_wrapper maps to the first positional argument.
# NOTE(review): this presumes Gradio injects the gr.Progress tracker as the
# first positional argument because do_inference declares it first — confirm
# against the Gradio version in use.
layout = {
    tab: {"inputs": {"progress": None}, "outputs": {}, "buttons": {}}
    for tab in ("inference", "training", "settings")
}


# there's got to be a better way to go about this
def gradio_wrapper(inputs):
    """Adapt a keyword-argument callback to Gradio's positional calling style.

    The returned decorator maps ``args[i]`` onto the i-th name in ``inputs``.
    It then calls the wrapped function with keyword arguments only. Any
    exception is re-raised as ``gr.Error`` so that its message is shown in the UI.

    ``inputs`` is iterated on every call. Passing a live ``dict.keys()`` view
    therefore picks up components registered after decoration time.
    """
    def decorated(fun):
        # functools.wraps keeps fun's signature visible to introspection
        # (via __wrapped__).
        @functools.wraps(fun)
        def wrapped_function(*args, **kwargs):
            for i, key in enumerate(inputs):
                kwargs[key] = args[i]
            try:
                return fun(**kwargs)
            except Exception as e:
                raise gr.Error(str(e)) from e
        return wrapped_function
    return decorated


class timer:
    """Context manager that reports how long its body took.

    On exit, the elapsed time is shown as a Gradio info message. It is also
    printed to stdout, prefixed with an ISO-8601 timestamp. Exceptions from
    the body are not suppressed.
    """

    def __init__(self, msg="Elapsed time:"):
        self.msg = msg

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
        gr.Info(msg)
        print(f'[{datetime.now().isoformat()}] {msg}')


def get_model_paths(paths=None):
    """Return every ``*.yaml`` file found recursively under ``paths``.

    :param paths: directories to search. Defaults to ``./data/``,
        ``./training/`` and ``./models/``. Directories that do not exist
        are skipped.
    :returns: list of ``Path`` objects. Files inside any ``logs`` directory
        are excluded.
    """
    if paths is None:
        paths = [Path("./data/"), Path("./training/"), Path("./models/")]

    yamls = []
    for path in paths:
        if not path.exists():
            continue
        for yaml in path.glob("**/*.yaml"):
            # Compare directory components instead of searching for "/logs/"
            # in the string, so the filter also works with "\" separators.
            if "logs" in yaml.parts[:-1]:
                continue
            yamls.append(yaml)
    return yamls


def get_dtypes():
    """Return the precision choices offered in the Settings tab."""
    return ["float32", "float16", "bfloat16", "float8_e5m2", "float8_e4m3fn", "auto"]


#@gradio_wrapper(inputs=layout["settings"]["inputs"].keys())
def load_model(yaml, device, dtype):
    """(Re)build the shared classifier from the Settings tab selections.

    Gradio calls this with the values of the model, device and precision
    dropdowns, in that order. Failures are surfaced in the UI as ``gr.Error``.
    """
    gr.Info(f"Loading: {yaml}")
    try:
        init_classifier(yaml=Path(yaml), restart=True, device=device, dtype=dtype)
    except Exception as e:
        raise gr.Error(str(e)) from e
    gr.Info("Loaded model")


def init_classifier(yaml=None, restart=False, device="cuda", dtype="auto"):
    """Return the shared Classifier, creating it if needed.

    :param yaml: config path. When None, ``--yaml`` from the command line is
        used, falling back to ``$CLASSIFIER_YAML``.
    :param restart: discard any existing classifier and build a new one.
    :param device: default for ``--device``.
    :param dtype: default for ``--dtype``. ``"auto"`` is passed to Classifier
        as ``dtype=None``.
    :returns: the (possibly newly created) Classifier.

    ``--device``, ``--dtype`` and ``--amp`` given on the command line take
    precedence over the ``device``/``dtype`` arguments.
    """
    global classifier

    if classifier is not None:
        if not restart:
            return classifier
        # Drop the old instance before constructing its replacement.
        del classifier
        classifier = None

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', yaml))  # os environ so it can be specified in a HuggingFace Space too
    parser.add_argument("--device", type=str, default=device)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--dtype", type=str, default=dtype)
    args, unknown = parser.parse_known_args()

    classifier = Classifier(
        config=args.yaml if yaml is None else yaml,
        device=args.device,
        dtype=args.dtype if args.dtype != "auto" else None,
        amp=args.amp,
    )
    return classifier


@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
def do_inference(progress=gr.Progress(track_tqdm=True), *args, **kwargs):
    """Run the classifier on the uploaded image and return its answer.

    gradio_wrapper supplies two keyword arguments. ``image`` is a file path,
    because the Image component uses ``type="filepath"``. ``temp`` is the
    sampling temperature. ``--image`` / ``--temp`` on the command line
    override both.

    :raises Exception: when no config has been loaded or no image was given.
    """
    if not cfg.yaml_path:
        raise Exception("No YAML loaded.")

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # I'm very sure I can procedurally generate this list
    parser.add_argument("--image", type=str, default=kwargs["image"])
    parser.add_argument("--temp", type=float, default=kwargs["temp"])
    args, unknown = parser.parse_known_args()

    if not args.image:
        # Fail with a clear message instead of an opaque error from Image.open(None).
        raise ValueError("No image provided.")

    classifier = init_classifier()

    # Decode inside a context manager so the source file is closed afterwards.
    with Image.open(args.image) as img:
        image = img.convert('RGB')

    gr.Info("Inferencing...")
    with timer("Inferenced in"):
        answer = classifier.inference(
            image=image,
            temperature=args.temp,
        )
    return answer


def _parse_listen(listen):
    """Split a ``[host:port][/path]`` string into ``(host, port, path)``.

    Missing parts default to host ``"127.0.0.1"``, port ``None`` and path
    ``"/"``. A port of 0 also becomes ``None``. If the string does not match,
    a warning is printed and ``(None, None, None)`` is returned, so Gradio
    uses its own defaults.
    """
    match = re.match(r"^(?:(.+?):(\d+))?(\/.*?)?$", listen)
    if match is None:
        print(f'Could not parse --listen="{listen}"; falling back to Gradio defaults.')
        return None, None, None

    host, port, path = match.groups()
    port = int(port) if port else None
    if port == 0:
        port = None
    return host or "127.0.0.1", port, path or "/"


# setup args
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('CLASSIFIER_YAML', None))  # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--listen", default=None, help="Path for Gradio to listen on")
parser.add_argument("--share", action="store_true")
parser.add_argument("--render_markdown", action="store_true", default="CLASSIFIER_YAML" in os.environ)
args, unknown = parser.parse_known_args()

args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
    args.listen_host, args.listen_port, args.listen_path = _parse_listen(args.listen)

# setup gradio
ui = gr.Blocks()
with ui:
    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=4):
                layout["inference"]["inputs"]["image"] = gr.Image(label="Input Image", sources=["upload"], type="filepath")
                layout["inference"]["outputs"]["output"] = gr.Textbox(label="Output")
            with gr.Column(scale=4):
                with gr.Row():
                    layout["inference"]["inputs"]["temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature", info="Modifies the randomness from the samples. (0 to greedy sample)")
                layout["inference"]["buttons"]["inference"] = gr.Button(value="Inference")

        layout["inference"]["buttons"]["inference"].click(
            fn=do_inference,
            inputs=[x for x in layout["inference"]["inputs"].values() if x is not None],
            outputs=[x for x in layout["inference"]["outputs"].values() if x is not None],
        )

    # Training tab is disabled: `do_training` is not defined in this module.
    # with gr.Tab("Training"):
    #     with gr.Row():
    #         with gr.Column(scale=1):
    #             layout["training"]["outputs"]["console"] = gr.Textbox(lines=8, label="Console Log")
    #     with gr.Row():
    #         with gr.Column(scale=1):
    #             layout["training"]["buttons"]["train"] = gr.Button(value="Train")
    #     layout["training"]["buttons"]["train"].click(
    #         fn=do_training,
    #         outputs=[x for x in layout["training"]["outputs"].values() if x is not None],
    #     )

    with gr.Tab("Settings"):
        with gr.Row():
            with gr.Column(scale=7):
                with gr.Row():
                    layout["settings"]["inputs"]["models"] = gr.Dropdown(choices=get_model_paths(), value=args.yaml, label="Model")
                    # NOTE(review): "cuda:0" may not be among get_devices() on CPU-only hosts — confirm.
                    layout["settings"]["inputs"]["device"] = gr.Dropdown(choices=get_devices(), value="cuda:0", label="Device")
                    layout["settings"]["inputs"]["dtype"] = gr.Dropdown(choices=get_dtypes(), value="auto", label="Precision")
            with gr.Column(scale=1):
                layout["settings"]["buttons"]["load"] = gr.Button(value="Load Model")

        layout["settings"]["buttons"]["load"].click(
            fn=load_model,
            inputs=[x for x in layout["settings"]["inputs"].values() if x is not None],
            outputs=[x for x in layout["settings"]["outputs"].values() if x is not None],
        )

    if os.path.exists("README.md") and args.render_markdown:
        with open("README.md", "r", encoding="utf-8") as f:
            md = f.read()
        # remove HF's metadata (front matter between the first two "---")
        if md.startswith("---\n"):
            # Split at most twice so later "---" (horizontal rules) survive.
            parts = md.split("---", 2)
            md = parts[2] if len(parts) > 2 else ""
        gr.Markdown(md)


def start(lock=True):
    """Launch the Gradio server.

    :param lock: block the calling thread while serving. Pass False to return
        immediately (Gradio's ``prevent_thread_lock``).
    """
    ui.queue(max_size=8)
    ui.launch(share=args.share, server_name=args.listen_host, server_port=args.listen_port, prevent_thread_lock=not lock)


if __name__ == "__main__":
    start()