From ba2ca9c24df3fcce20689296276f5832fe0cde5b Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 5 Aug 2023 16:50:53 +0000 Subject: [PATCH] added small listen server to allow inferencing (todo: allow reading from base64) --- image_classifier/__main__.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/image_classifier/__main__.py b/image_classifier/__main__.py index 9ae5741..b6dfcc9 100755 --- a/image_classifier/__main__.py +++ b/image_classifier/__main__.py @@ -3,17 +3,30 @@ from pathlib import Path from .inference import CAPTCHA def main(): - parser = argparse.ArgumentParser("CAPTCHA") - parser.add_argument("path", type=Path) + parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False) + parser.add_argument("--listen", action='store_true') + parser.add_argument("--port", type=int, default=9090) parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--ckpt", type=Path, default=None) parser.add_argument("--temp", type=float, default=1.0) parser.add_argument("--device", default="cuda") - args = parser.parse_args() + args, unknown = parser.parse_known_args() captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device ) - answer = captcha.inference( path=args.path, temperature=args.temp ) - print("Answer:", answer) + if args.listen: + from simple_http_server import route, server + + @route("/") + def inference( path, temperature=1.0 ): + return { "answer": captcha.inference( path=Path(path), temperature=args.temp ) } + server.start(port=args.port) + else: + parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False) + parser.add_argument("path", type=Path) + args2, unknown = parser.parse_known_args() + + answer = captcha.inference( path=args2.path, temperature=args.temp ) + print("Answer:", answer) if __name__ == "__main__": main()