"""Command-line entry point for the image classifier.

Two modes are supported:

* ``--listen``: start an HTTP server whose ``/`` route accepts a
  base64-encoded image (parameter ``b64``) and returns ``{"answer": ...}``.
* otherwise: classify a single image (``--path`` file or ``--base64``
  string) or every JPEG/PNG found recursively under ``--path`` when it is a
  directory, printing each answer and optionally saving a copy of the image
  named after it into ``--write``.
"""
import argparse
import base64
from io import BytesIO
from pathlib import Path

from PIL import Image
from simple_http_server import route, server

from .inference import Classifier

# Suffixes (compared case-insensitively) collected when --path is a directory.
_IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

# --temp defaults differ per mode; these match the script's historical values.
_SERVER_DEFAULT_TEMP = 0.0
_CLI_DEFAULT_TEMP = 1.0


def _build_parser():
    """Return the single argument parser used by both modes."""
    parser = argparse.ArgumentParser(allow_abbrev=False)
    # Mode / server options.
    parser.add_argument("--listen", action="store_true",
                        help="serve requests over HTTP instead of running once")
    parser.add_argument("--port", type=int, default=9090,
                        help="HTTP port used with --listen")
    # Options forwarded to Classifier(...).
    parser.add_argument("--yaml", type=Path, default=None,
                        help="passed to Classifier as config")
    parser.add_argument("--device", type=str, default=None,
                        help="passed to Classifier as device")
    parser.add_argument("--amp", action="store_true",
                        help="passed to Classifier as amp")
    parser.add_argument("--dtype", type=str, default=None,
                        help="passed to Classifier as dtype")
    # None means "use the mode-specific default", resolved in main().
    parser.add_argument("--temp", type=float, default=None,
                        help=f"temperature (default: {_SERVER_DEFAULT_TEMP} with "
                             f"--listen, {_CLI_DEFAULT_TEMP} otherwise)")
    # One-shot CLI inputs / outputs (ignored with --listen).
    parser.add_argument("--path", type=Path,
                        help="image file, or directory searched recursively")
    parser.add_argument("--base64", type=str, help="base64-encoded image")
    parser.add_argument("--write", type=Path,
                        help="directory to save each image as <answer>.jpg")
    return parser


def _open_rgb(source):
    """Open *source* (a path or binary file object) and return an RGB copy.

    The ``with`` block closes the underlying file; ``convert`` returns a new,
    fully loaded image, so the result stays usable afterwards.
    """
    with Image.open(source) as img:
        return img.convert("RGB")


def _decode_base64_image(data):
    """Decode a base64-encoded image string into an RGB PIL image."""
    return _open_rgb(BytesIO(base64.b64decode(data)))


def _iter_images(path, b64):
    """Yield RGB images selected by ``--path`` or, failing that, ``--base64``.

    Images are opened lazily, one at a time, so a large directory is never
    held in memory all at once. The caller must ensure one of the two inputs
    is set.
    """
    if path:
        if path.is_dir():
            # Sorted for a deterministic processing order; is_file() skips
            # directories whose names happen to end in an image suffix.
            for p in sorted(path.rglob("*")):
                if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES:
                    yield _open_rgb(p)
        else:
            yield _open_rgb(path)
    elif b64:
        yield _decode_base64_image(b64)


def _serve(classifier, port, default_temperature):
    """Register the ``/`` route and start the HTTP server on *port*."""

    @route("/")
    def inference(b64, temperature=default_temperature):
        image = _decode_base64_image(b64)
        # NOTE(review): request parameters may be delivered as strings by
        # simple_http_server; coerce so the classifier always gets a float.
        return {
            "answer": classifier.inference(
                image=image, temperature=float(temperature)
            )
        }

    server.start(port=port)


def _run_cli(classifier, args, temperature):
    """Classify each selected image, print the answer, optionally save a copy."""
    if args.write:
        # Created once up front; parents=True allows nested output paths.
        args.write.mkdir(parents=True, exist_ok=True)
    for image in _iter_images(args.path, args.base64):
        answer = classifier.inference(image=image, temperature=temperature)
        print("Answer:", answer)
        if args.write:
            # NOTE: images that get the same answer overwrite the same file.
            image.save(args.write / f"{answer}.jpg")


def main():
    """Parse the command line, build the classifier and run the chosen mode."""
    parser = _build_parser()
    # parse_known_args keeps the script tolerant of extra options, as before.
    args, _unknown = parser.parse_known_args()

    # Validate before constructing the classifier so a usage mistake does not
    # pay for model loading. parser.error prints usage and exits with status 2.
    if not args.listen and not (args.path or args.base64):
        parser.error("Specify a --path or --base64.")

    classifier = Classifier(
        config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp
    )

    if args.listen:
        temp = _SERVER_DEFAULT_TEMP if args.temp is None else args.temp
        _serve(classifier, args.port, temp)
    else:
        temp = _CLI_DEFAULT_TEMP if args.temp is None else args.temp
        _run_cli(classifier, args, temp)


if __name__ == "__main__":
    main()