resnet-classifier/image_classifier/inference.py

53 lines
1.3 KiB
Python
Executable File

import torch
import torchvision.transforms as transforms
from .config import cfg
from .export import load_models
from .data import get_symmap, _get_symbols
class CAPTCHA:
    """Wrap a trained model for single-image CAPTCHA inference.

    On construction the model is loaded either from a checkpoint file
    (``ckpt``) or from the first entry returned by ``load_models()``, switched
    to eval mode, and paired with a preprocessing pipeline that resizes the
    input to ``(height, width)``, converts it to a tensor and normalizes it
    with the standard ImageNet mean/std values.

    Args:
        width: Target width in pixels for the ``Resize`` step.
        height: Target height in pixels for the ``Resize`` step.
        config: Optional path to a YAML file; when given it is loaded into the
            shared ``cfg`` object via ``cfg.load_yaml`` before the model is
            built.
        ckpt: Optional path to a checkpoint produced by ``torch.save``. When
            set it takes precedence over ``load_models()``.
        device: Device the model and the preprocessed input are moved to.
        dtype: Currently unused; accepted for interface compatibility.
            NOTE(review): no dtype conversion is applied anywhere, so the
            model runs in whatever dtype it was loaded with.

    Attributes:
        loading: Set to ``True`` at the start of ``__init__`` and to ``False``
            once construction completes.
    """

    def __init__(self, width=300, height=80, config=None, ckpt=None,
                 device="cuda", dtype="float32"):
        self.loading = True
        self.device = device
        if config:
            cfg.load_yaml(config)
        if ckpt:
            self.load_model_from_ckpt(ckpt)
        else:
            self.load_model_from_cfg(config)
        # Eval mode changes the behavior of layers such as dropout and batch
        # norm to their inference form.
        self.model.eval()
        self.width = width
        self.height = height
        self.transform = transforms.Compose([
            # torchvision's Resize takes (height, width).
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.loading = False

    def load_model_from_ckpt(self, ckpt):
        """Load a pickled model from ``ckpt`` and move it to ``self.device``.

        ``map_location`` remaps stored tensors onto ``self.device`` while
        loading, so a checkpoint saved on a GPU can be opened on a CPU-only
        machine instead of failing inside ``torch.load``.

        SECURITY: ``torch.load`` unpickles arbitrary objects; only load
        checkpoints from trusted sources.

        Args:
            ckpt: Path to the checkpoint file; also stored as ``self.ckpt``.
        """
        self.ckpt = ckpt
        # NOTE(review): the result is used as a module (.to/.eval/__call__),
        # so this expects a fully pickled model rather than a state_dict. On
        # torch >= 2.6 the default ``weights_only=True`` may reject such
        # files — confirm against the torch version in use.
        model = torch.load(self.ckpt, map_location=self.device)
        self.model = model.to(self.device)

    def load_model_from_cfg(self, config_path):
        """Use the first model returned by ``load_models()``.

        Args:
            config_path: Unused here; ``__init__`` loads ``config`` into
                ``cfg`` before calling this method.

        Raises:
            RuntimeError: If ``load_models()`` returns an empty collection.
        """
        models = load_models()
        try:
            first_name = next(iter(models))
        except StopIteration:
            raise RuntimeError("load_models() returned no models") from None
        self.model = models[first_name].to(self.device)

    def inference(self, image, temperature=1.0):
        """Predict the text shown in a single CAPTCHA image.

        Args:
            image: Input accepted by ``self.transform`` (e.g. a PIL image).
                PIL-style images whose ``mode`` is not ``"RGB"`` (RGBA,
                palette, grayscale, ...) are converted to RGB first so the
                3-channel normalization step does not fail.
            temperature: Forwarded to the model as ``sampling_temperature``.

        Returns:
            The first string produced by the model with the ``<s>`` and
            ``</s>`` markers removed.
        """
        # Normalize uses 3-channel statistics; 1- or 4-channel inputs would
        # fail to broadcast, so convert anything exposing a PIL-style mode.
        mode = getattr(image, "mode", None)
        if isinstance(mode, str) and mode != "RGB" and hasattr(image, "convert"):
            image = image.convert("RGB")
        tensor = self.transform(image).to(self.device)
        # Prediction needs no gradients; skipping the autograd graph saves
        # memory and time on every call.
        with torch.no_grad():
            outputs = self.model(image=[tensor], sampling_temperature=temperature)
        # NOTE(review): assumes the model returns a sequence of strings.
        return outputs[0].replace('<s>', "").replace("</s>", "")