added ability to either specify a raw path or a base64 encoded string of an image

This commit is contained in:
mrq 2023-08-05 18:14:28 +00:00
parent ba2ca9c24d
commit 77a9625e93
2 changed files with 29 additions and 13 deletions

View File

@ -1,7 +1,13 @@
import argparse import argparse
import base64
from io import BytesIO
from pathlib import Path from pathlib import Path
from .inference import CAPTCHA from .inference import CAPTCHA
from PIL import Image
from simple_http_server import route, server
def main(): def main():
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False) parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
parser.add_argument("--listen", action='store_true') parser.add_argument("--listen", action='store_true')
@ -14,18 +20,26 @@ def main():
captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device ) captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
if args.listen: if args.listen:
from simple_http_server import route, server
@route("/") @route("/")
def inference( path, temperature=1.0 ): def inference( b64, temperature=1.0 ):
return { "answer": captcha.inference( path=Path(path), temperature=args.temp ) } image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
return { "answer": captcha.inference( image=image, temperature=args.temp ) }
server.start(port=args.port) server.start(port=args.port)
else: else:
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False) parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
parser.add_argument("path", type=Path) parser.add_argument("--path", type=Path)
args2, unknown = parser.parse_known_args() parser.add_argument("--base64", type=str)
parser.add_argument("--temp", type=float, default=1.0)
args, unknown = parser.parse_known_args()
answer = captcha.inference( path=args2.path, temperature=args.temp ) if args.path:
image = Image.open(args.path).convert('RGB')
elif args.base64:
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
else:
raise "Specify a --path or --base64."
answer = captcha.inference( image=image, temperature=args.temp )
print("Answer:", answer) print("Answer:", answer)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,21 +1,24 @@
import torch import torch
from PIL import Image
import torchvision.transforms as transforms import torchvision.transforms as transforms
from .config import cfg from .config import cfg
from .export import load_models from .export import load_models
from .data import get_symmap, _get_symbols
class CAPTCHA(): class CAPTCHA():
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ): def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
self.loading = True self.loading = True
self.device = device self.device = device
if config:
cfg.load_yaml( config )
if ckpt: if ckpt:
self.load_model_from_ckpt( ckpt ) self.load_model_from_ckpt( ckpt )
else: else:
self.load_model_from_cfg( config ) self.load_model_from_cfg( config )
self.model.eval() self.model.eval()
self.width = width self.width = width
@ -34,8 +37,6 @@ class CAPTCHA():
self.model = torch.load(self.ckpt).to(self.device) self.model = torch.load(self.ckpt).to(self.device)
def load_model_from_cfg( self, config_path ): def load_model_from_cfg( self, config_path ):
if config_path:
cfg.load_yaml( config_path )
models = load_models() models = load_models()
for name in models: for name in models:
@ -43,8 +44,9 @@ class CAPTCHA():
self.model = model.to(self.device) self.model = model.to(self.device)
break break
def inference( self, path, temperature=1.0 ): def inference( self, image, temperature=1.0 ):
image = self.transform(Image.open(path).convert('RGB')).to(self.device) image = self.transform(image).to(self.device)
answer = self.model( image=[image], sampling_temperature=temperature ) answer = self.model( image=[image], sampling_temperature=temperature )
answer = answer[0].replace('<s>', "").replace("</s>", "") answer = answer[0].replace('<s>', "").replace("</s>", "")