From 77a9625e933fb65130a877be1e96785b126a6f2f Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 5 Aug 2023 18:14:28 +0000 Subject: [PATCH] added ability to either specify a raw path or a base64 encoded string of an image --- image_classifier/__main__.py | 28 +++++++++++++++++++++------- image_classifier/inference.py | 14 ++++++++------ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/image_classifier/__main__.py b/image_classifier/__main__.py index b6dfcc9..6f7019b 100755 --- a/image_classifier/__main__.py +++ b/image_classifier/__main__.py @@ -1,7 +1,13 @@ import argparse +import base64 + +from io import BytesIO from pathlib import Path from .inference import CAPTCHA +from PIL import Image +from simple_http_server import route, server + def main(): parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False) parser.add_argument("--listen", action='store_true') @@ -14,18 +20,26 @@ def main(): captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device ) if args.listen: - from simple_http_server import route, server - @route("/") - def inference( path, temperature=1.0 ): - return { "answer": captcha.inference( path=Path(path), temperature=args.temp ) } + def inference( b64, temperature=1.0 ): + image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB") + return { "answer": captcha.inference( image=image, temperature=args.temp ) } server.start(port=args.port) else: parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False) - parser.add_argument("path", type=Path) - args2, unknown = parser.parse_known_args() + parser.add_argument("--path", type=Path) + parser.add_argument("--base64", type=str) + parser.add_argument("--temp", type=float, default=1.0) + args, unknown = parser.parse_known_args() - answer = captcha.inference( path=args2.path, temperature=args.temp ) + if args.path: + image = Image.open(args.path).convert('RGB') + elif args.base64: + image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB") + else: + 
raise ValueError("Specify a --path or --base64.") + + answer = captcha.inference( image=image, temperature=args.temp ) print("Answer:", answer) if __name__ == "__main__": diff --git a/image_classifier/inference.py b/image_classifier/inference.py index 1e5e318..43e4824 100755 --- a/image_classifier/inference.py +++ b/image_classifier/inference.py @@ -1,21 +1,24 @@ import torch -from PIL import Image import torchvision.transforms as transforms from .config import cfg from .export import load_models +from .data import get_symmap, _get_symbols class CAPTCHA(): def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ): self.loading = True self.device = device + if config: + cfg.load_yaml( config ) + if ckpt: self.load_model_from_ckpt( ckpt ) else: self.load_model_from_cfg( config ) - + self.model.eval() self.width = width @@ -34,8 +37,6 @@ class CAPTCHA(): self.model = torch.load(self.ckpt).to(self.device) def load_model_from_cfg( self, config_path ): - if config_path: - cfg.load_yaml( config_path ) models = load_models() for name in models: @@ -43,8 +44,9 @@ class CAPTCHA(): self.model = model.to(self.device) break - def inference( self, path, temperature=1.0 ): - image = self.transform(Image.open(path).convert('RGB')).to(self.device) + def inference( self, image, temperature=1.0 ): + image = self.transform(image).to(self.device) + answer = self.model( image=[image], sampling_temperature=temperature ) answer = answer[0].replace('', "").replace("", "")