added ability to either specify a raw path or a base64 encoded string of an image
This commit is contained in:
parent
ba2ca9c24d
commit
77a9625e93
|
@ -1,7 +1,13 @@
|
|||
import argparse
|
||||
import base64
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from .inference import CAPTCHA
|
||||
|
||||
from PIL import Image
|
||||
from simple_http_server import route, server
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
|
||||
parser.add_argument("--listen", action='store_true')
|
||||
|
@ -14,18 +20,26 @@ def main():
|
|||
|
||||
captcha = CAPTCHA( config=args.yaml, ckpt=args.ckpt, device=args.device )
|
||||
if args.listen:
|
||||
from simple_http_server import route, server
|
||||
|
||||
@route("/")
|
||||
def inference( path, temperature=1.0 ):
|
||||
return { "answer": captcha.inference( path=Path(path), temperature=args.temp ) }
|
||||
def inference( b64, temperature=1.0 ):
|
||||
image = Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")
|
||||
return { "answer": captcha.inference( image=image, temperature=args.temp ) }
|
||||
server.start(port=args.port)
|
||||
else:
|
||||
parser = argparse.ArgumentParser("CAPTCHA", allow_abbrev=False)
|
||||
parser.add_argument("path", type=Path)
|
||||
args2, unknown = parser.parse_known_args()
|
||||
parser.add_argument("--path", type=Path)
|
||||
parser.add_argument("--base64", type=str)
|
||||
parser.add_argument("--temp", type=float, default=1.0)
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
answer = captcha.inference( path=args2.path, temperature=args.temp )
|
||||
if args.path:
|
||||
image = Image.open(args.path).convert('RGB')
|
||||
elif args.base64:
|
||||
image = Image.open(BytesIO(base64.b64decode(args.base64))).convert("RGB")
|
||||
else:
|
||||
raise "Specify a --path or --base64."
|
||||
|
||||
answer = captcha.inference( image=image, temperature=args.temp )
|
||||
print("Answer:", answer)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -1,21 +1,24 @@
|
|||
import torch
|
||||
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from .config import cfg
|
||||
from .export import load_models
|
||||
from .data import get_symmap, _get_symbols
|
||||
|
||||
class CAPTCHA():
|
||||
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
|
||||
self.loading = True
|
||||
self.device = device
|
||||
|
||||
if config:
|
||||
cfg.load_yaml( config )
|
||||
|
||||
if ckpt:
|
||||
self.load_model_from_ckpt( ckpt )
|
||||
else:
|
||||
self.load_model_from_cfg( config )
|
||||
|
||||
|
||||
self.model.eval()
|
||||
|
||||
self.width = width
|
||||
|
@ -34,8 +37,6 @@ class CAPTCHA():
|
|||
self.model = torch.load(self.ckpt).to(self.device)
|
||||
|
||||
def load_model_from_cfg( self, config_path ):
|
||||
if config_path:
|
||||
cfg.load_yaml( config_path )
|
||||
|
||||
models = load_models()
|
||||
for name in models:
|
||||
|
@ -43,8 +44,9 @@ class CAPTCHA():
|
|||
self.model = model.to(self.device)
|
||||
break
|
||||
|
||||
def inference( self, path, temperature=1.0 ):
|
||||
image = self.transform(Image.open(path).convert('RGB')).to(self.device)
|
||||
def inference( self, image, temperature=1.0 ):
|
||||
image = self.transform(image).to(self.device)
|
||||
|
||||
answer = self.model( image=[image], sampling_temperature=temperature )
|
||||
answer = answer[0].replace('<s>', "").replace("</s>", "")
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user