72 lines
2.3 KiB
Python
Executable File
72 lines
2.3 KiB
Python
Executable File
import argparse
import base64
import itertools
from io import BytesIO
from pathlib import Path

from PIL import Image
from simple_http_server import route, server

from .inference import Classifier


# Mode-specific defaults for --temp: 0.0 when serving over HTTP (--listen),
# 1.0 for one-shot CLI classification.
_SERVER_DEFAULT_TEMPERATURE = 0.0
_CLI_DEFAULT_TEMPERATURE = 1.0


def _build_parser():
    """Return a single parser covering both the --listen server and the one-shot CLI."""
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument(
        "--listen", action="store_true",
        help="serve classifications over HTTP instead of classifying local input",
    )
    parser.add_argument("--port", type=int, default=9090, help="HTTP port used with --listen")
    parser.add_argument("--yaml", type=Path, default=None, help="passed to Classifier as config")
    parser.add_argument("--device", type=str, default=None, help="passed to Classifier")
    parser.add_argument("--amp", action="store_true", help="passed to Classifier")
    parser.add_argument("--dtype", type=str, default=None, help="passed to Classifier")
    # None means "not given on the command line"; main() substitutes the
    # mode-specific default so both modes keep their historical defaults.
    parser.add_argument(
        "--temp", type=float, default=None,
        help=(
            "temperature passed to Classifier.inference (default: "
            f"{_SERVER_DEFAULT_TEMPERATURE} with --listen, {_CLI_DEFAULT_TEMPERATURE} otherwise)"
        ),
    )
    # One-shot CLI options; ignored when --listen is given.
    parser.add_argument(
        "--path", type=Path,
        help="image file, or directory searched recursively for *.jpg and *.png",
    )
    parser.add_argument(
        "--base64", type=str, help="base64-encoded image (used only when --path is not given)"
    )
    parser.add_argument(
        "--write", type=Path, help="directory in which each image is saved as <answer>.jpg"
    )
    parser.add_argument(
        "--limit", type=int, default=0,
        help="maximum number of images loaded from a --path directory (0 = no limit)",
    )
    return parser


def _decode_base64_image(b64):
    """Decode a base64-encoded image payload into an RGB PIL image."""
    # NOTE: in --listen mode this payload comes from HTTP clients; malformed
    # base64 or non-image bytes raise from b64decode / Image.open.
    with Image.open(BytesIO(base64.b64decode(b64))) as image:
        return image.convert("RGB")


def _load_rgb(path):
    """Open the image at *path* and return an RGB copy, closing the file."""
    # convert() loads the pixels and returns a new image, so the source file
    # can be closed as soon as the with-block exits.
    with Image.open(path) as image:
        return image.convert("RGB")


def _iter_image_paths(root, limit=0):
    """Return an iterator over ``*.jpg`` then ``*.png`` files under *root* (recursive).

    At most *limit* paths are produced in total; ``limit <= 0`` means no limit.
    """
    paths = itertools.chain(root.rglob("*.jpg"), root.rglob("*.png"))
    # One limit across both extensions (the old per-loop checks let one extra
    # .png through after the .jpg loop had already reached the limit).
    return itertools.islice(paths, limit if limit > 0 else None)


def _iter_cli_images(args):
    """Yield the RGB images selected by --path (file or directory) or --base64.

    --path takes precedence over --base64. Directory images are loaded lazily,
    one at a time, instead of all being held in memory up front.
    """
    if args.path is not None:
        if args.path.is_dir():
            for image_path in _iter_image_paths(args.path, args.limit):
                yield _load_rgb(image_path)
        else:
            yield _load_rgb(args.path)
    elif args.base64:
        yield _decode_base64_image(args.base64)


def _serve(classifier, port, default_temperature):
    """Register the "/" route and run the HTTP server on *port*.

    Each request supplies ``b64`` (a base64-encoded image) and optionally
    ``temperature``; the response body is ``{"answer": ...}``.
    """

    @route("/")
    def inference(b64, temperature=default_temperature):
        image = _decode_base64_image(b64)
        # Request parameters may be bound as strings; float() is a no-op for
        # the numeric default and converts a client-supplied value.
        return {"answer": classifier.inference(image=image, temperature=float(temperature))}

    server.start(port=port)


def _run_cli(classifier, images, temperature, write_dir):
    """Classify each image, print its answer and optionally save it under *write_dir*."""
    if write_dir is not None:
        write_dir.mkdir(parents=True, exist_ok=True)
    for image in images:
        answer = classifier.inference(image=image, temperature=temperature)
        print("Answer:", answer)
        if write_dir is not None:
            # NOTE(review): the answer is used verbatim as the file name, so
            # images with equal answers overwrite each other — confirm intended.
            image.save(write_dir / f"{answer}.jpg")


def main():
    """Run the classifier as an HTTP service (--listen) or on local/base64 images.

    Unrecognized command-line arguments are ignored (``parse_known_args``).
    Exits with a usage error when neither --listen, --path nor --base64 is given.
    """
    parser = _build_parser()
    args, _unknown = parser.parse_known_args()

    # Validate before building the classifier, so a usage error is reported
    # without constructing it first (raising a bare string was a TypeError).
    if not args.listen and args.path is None and not args.base64:
        parser.error("Specify a --path or --base64.")

    classifier = Classifier(config=args.yaml, device=args.device, dtype=args.dtype, amp=args.amp)

    if args.listen:
        temperature = _SERVER_DEFAULT_TEMPERATURE if args.temp is None else args.temp
        _serve(classifier, args.port, temperature)
        return

    temperature = _CLI_DEFAULT_TEMPERATURE if args.temp is None else args.temp
    _run_cli(classifier, _iter_cli_images(args), temperature, args.write)


if __name__ == "__main__":
    # Only start the CLI/server when executed directly, never on import.
    main()