resnet-classifier/image_classifier/__main__.py

47 lines
1.6 KiB
Python
Raw Normal View History

2023-08-05 03:40:14 +00:00
import argparse
import base64
from io import BytesIO
2023-08-05 03:40:14 +00:00
from pathlib import Path
from .inference import Classifier
2023-08-05 03:40:14 +00:00
from PIL import Image
from simple_http_server import route, server
2023-08-05 03:40:14 +00:00
def _build_parser():
    """Build the single command-line parser shared by server and one-shot modes.

    Server mode uses ``--listen``/``--port``; one-shot mode uses ``--path`` or
    ``--base64``. ``--yaml``/``--ckpt``/``--device``/``--temp`` apply to both.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False)
    # Server mode.
    parser.add_argument("--listen", action='store_true')
    parser.add_argument("--port", type=int, default=9090)
    # Model configuration (forwarded to Classifier / Classifier.inference).
    parser.add_argument("--yaml", type=Path, default=None)
    parser.add_argument("--ckpt", type=Path, default=None)
    parser.add_argument("--temp", type=float, default=1.0)
    parser.add_argument("--device", default="cuda")
    # One-shot input; ignored when --listen is given.
    parser.add_argument("--path", type=Path)
    parser.add_argument("--base64", type=str)
    return parser


def _decode_base64_image(b64):
    """Decode a base64-encoded image payload and return it as an RGB PIL image."""
    return Image.open(BytesIO(base64.b64decode(b64))).convert("RGB")


def _load_image_file(path):
    """Open the image at *path*, convert it to RGB, and close the file handle."""
    with Image.open(path) as img:
        return img.convert("RGB")


def main():
    """Run the classifier as an HTTP server (``--listen``) or on a single image.

    In server mode, a route at ``/`` accepts a base64-encoded image (``b64``)
    and an optional ``temperature``. When ``temperature`` is omitted, ``--temp``
    is used. In one-shot mode, exactly one of ``--path`` or ``--base64`` must be
    supplied and the answer is printed to stdout.

    Exits with a usage error (status 2) when one-shot mode gets no input.
    """
    parser = _build_parser()
    # parse_known_args tolerates extra flags (as the original did); they are ignored.
    args, _ = parser.parse_known_args()

    # Validate one-shot input before constructing the model, so a usage
    # mistake does not pay for model loading first.
    if not args.listen and not (args.path or args.base64):
        parser.error("Specify a --path or --base64.")

    classifier = Classifier(config=args.yaml, ckpt=args.ckpt, device=args.device)

    if args.listen:
        @route("/")
        def inference(b64, temperature=None):
            # NOTE(review): the HTTP layer may hand request values over as
            # strings -- confirm; float() accepts both str and float.
            temp = args.temp if temperature is None else float(temperature)
            image = _decode_base64_image(b64)
            return {"answer": classifier.inference(image=image, temperature=temp)}

        server.start(port=args.port)
        return

    if args.path:
        image = _load_image_file(args.path)
    else:
        image = _decode_base64_image(args.base64)
    answer = classifier.inference(image=image, temperature=args.temp)
    print("Answer:", answer)
2023-08-05 03:40:14 +00:00
# Run the CLI when this module is executed directly (e.g. `python -m image_classifier`).
if __name__ == "__main__":
    main()