49 lines
1.3 KiB
Python
Executable File
49 lines
1.3 KiB
Python
Executable File
import torch
|
|
|
|
from PIL import Image
|
|
import torchvision.transforms as transforms
|
|
|
|
from .config import cfg
|
|
from .export import load_models
|
|
|
|
class CAPTCHA():
    """Wrap a captcha-reading model behind a simple path-in, text-out API.

    On construction the model is loaded either from a serialized checkpoint
    (``ckpt``) or from the project configuration (``config``), moved to
    ``device``, and an image preprocessing pipeline (resize, tensor
    conversion, channel normalization) is built.

    Attributes set by the constructor:
        loading: ``True`` while the constructor is running, ``False`` once
            it has finished successfully.
        device: Device the model and the input images are moved to.
        model: The loaded model.
        width, height: Size every input image is resized to.
        transform: The torchvision preprocessing pipeline.
    """

    def __init__(self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32"):
        """Load the model and build the preprocessing pipeline.

        Args:
            width: Width in pixels that input images are resized to.
            height: Height in pixels that input images are resized to.
            config: Optional path to a YAML config; only used when ``ckpt``
                is not given.
            ckpt: Optional path to a checkpoint file; takes precedence over
                ``config`` when truthy.
            device: Device the model and inputs are placed on.
            dtype: Accepted for interface compatibility; it is currently
                not applied to the model or to the inputs.

        Raises:
            RuntimeError: If no checkpoint is given and the configuration
                yields no model.
        """
        self.loading = True
        self.device = device

        if ckpt:
            self.load_model_from_ckpt(ckpt)
        else:
            self.load_model_from_cfg(config)

        self.width = width
        self.height = height
        self.transform = transforms.Compose([
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            # Per-channel mean / std; presumably the usual ImageNet
            # statistics -- confirm against the training pipeline.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.loading = False

    def load_model_from_ckpt(self, ckpt):
        """Load a serialized model from ``ckpt`` and move it to ``self.device``.

        Args:
            ckpt: Path of the checkpoint; also stored as ``self.ckpt``.
        """
        self.ckpt = ckpt

        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints that come from a trusted source.
        # map_location remaps stored tensors onto the requested device, so a
        # checkpoint saved on a GPU can still be opened on a CPU-only host.
        loaded = torch.load(self.ckpt, map_location=self.device)
        self.model = loaded.to(self.device)

    def load_model_from_cfg(self, config_path):
        """Build the model from the project configuration.

        Args:
            config_path: Optional YAML path merged into the global ``cfg``
                before the models are built; skipped when falsy.

        Raises:
            RuntimeError: If ``load_models()`` returns no model.
        """
        if config_path:
            cfg.load_yaml(config_path)

        models = load_models()
        # Only the first model yielded by the collection is used.
        first_name = next(iter(models), None)
        if first_name is None:
            # Fail here instead of leaving ``self.model`` undefined and
            # crashing later with an AttributeError inside ``inference``.
            raise RuntimeError("load_models() returned no models")

        self.model = models[first_name].to(self.device)

    def inference(self, path, temperature=1.0):
        """Read the captcha image at ``path`` and return the decoded text.

        Args:
            path: Image location, passed to ``PIL.Image.open``.
            temperature: Forwarded to the model as ``sampling_temperature``.

        Returns:
            The first answer produced by the model with the ``<s>`` and
            ``</s>`` markers removed.
        """
        # The context manager closes the underlying file; ``convert``
        # returns an independent in-memory copy that outlives it.
        with Image.open(path) as source:
            rgb = source.convert('RGB')

        image = self.transform(rgb).to(self.device)

        # Only the decoded string is returned, so gradient tracking is
        # unnecessary overhead here.
        with torch.no_grad():
            answer = self.model(image=[image], sampling_temperature=temperature)

        return answer[0].replace('<s>', "").replace("</s>", "")