resnet-classifier/image_classifier/inference.py

51 lines
1.3 KiB
Python
Executable File

import torch
from PIL import Image
import torchvision.transforms as transforms
from .config import cfg
from .export import load_models
class CAPTCHA():
    """Load a model once and run it on image files to obtain a decoded string.

    The model comes either from a checkpoint file (``ckpt``) or from the
    project configuration (``config``); the checkpoint wins when both are
    given. Input images are resized to ``(height, width)``, converted to a
    tensor and normalised before being passed to the model.

    Attributes:
        loading: ``True`` while ``__init__`` is still running, ``False`` after.
        device: Device the model and the input tensors are moved to.
        model: The loaded model, switched to eval mode by ``__init__``.
        width, height: Target size used by the resize transform.
        transform: The preprocessing pipeline applied to every image.
    """

    def __init__(self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32"):
        """Load the model and build the preprocessing pipeline.

        Args:
            width: Width images are resized to before inference.
            height: Height images are resized to before inference.
            config: Optional path to a YAML config; only used when ``ckpt``
                is falsy.
            ckpt: Optional checkpoint path; takes precedence over ``config``.
            device: Device string passed to ``.to()`` for model and inputs.
            dtype: Accepted for interface compatibility; this class does not
                currently use it.

        Raises:
            RuntimeError: If the config path yields no models.
        """
        self.loading = True
        self.device = device

        if ckpt:
            self.load_model_from_ckpt(ckpt)
        else:
            self.load_model_from_cfg(config)
        self.model.eval()

        self.width = width
        self.height = height
        self.transform = transforms.Compose([
            # Resize takes (height, width), in that order.
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            # Per-channel mean / std (the commonly used ImageNet statistics).
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.loading = False

    def load_model_from_ckpt(self, ckpt):
        """Load a whole pickled model object from ``ckpt`` onto ``self.device``.

        Args:
            ckpt: Path to the checkpoint; also remembered as ``self.ckpt``.
        """
        self.ckpt = ckpt
        # map_location lets a checkpoint saved on one device (e.g. a GPU) be
        # restored on another (e.g. CPU-only host) instead of failing while
        # deserialising its storages.
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        self.model = torch.load(self.ckpt, map_location=self.device).to(self.device)

    def load_model_from_cfg(self, config_path):
        """Build models via ``load_models()`` and keep the first one.

        Args:
            config_path: Optional YAML path loaded into ``cfg`` first; when
                falsy the current ``cfg`` state is used as-is.

        Raises:
            RuntimeError: If ``load_models()`` returns an empty collection.
        """
        if config_path:
            cfg.load_yaml(config_path)

        models = load_models()
        # Only the first entry (in iteration order) is used.
        for name in models:
            self.model = models[name].to(self.device)
            return

        # Fail here with a clear message rather than later with an
        # AttributeError on the missing ``self.model``.
        raise RuntimeError("load_models() returned no models; nothing to run inference with")

    def inference(self, path, temperature=1.0):
        """Run the model on the image at ``path`` and return the cleaned text.

        Args:
            path: Image file to read; it is converted to RGB before
                preprocessing.
            temperature: Forwarded to the model as ``sampling_temperature``.

        Returns:
            The first element of the model output with the ``<s>`` and
            ``</s>`` markers removed.
        """
        # convert() yields an independent in-memory image, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(path) as source:
            rgb = source.convert('RGB')
        image = self.transform(rgb).to(self.device)

        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            answer = self.model(image=[image], sampling_temperature=temperature)

        return answer[0].replace('<s>', "").replace("</s>", "")