import torch
import torchvision.transforms as transforms
from PIL import Image

from .config import cfg
from .export import load_models


class CAPTCHA:
    """Load a CAPTCHA-recognition model and run it on single images.

    The model comes either from a checkpoint read with ``torch.load`` (when
    ``ckpt`` is given) or from the first entry returned by ``load_models()``,
    after optionally loading a YAML file into the shared ``cfg``.
    """

    def __init__(
        self,
        width=300,
        height=80,
        config=None,
        ckpt=None,
        device="cuda",
        dtype="float32",
    ):
        """Load the model and build the image preprocessing pipeline.

        Args:
            width: Width in pixels that every input image is resized to.
            height: Height in pixels that every input image is resized to.
            config: Optional YAML path passed to ``cfg.load_yaml``. Only used
                when ``ckpt`` is not given.
            ckpt: Optional checkpoint path loaded with ``torch.load``. Takes
                precedence over ``config``.
            device: Torch device that the model and input tensors are moved to.
            dtype: Accepted for interface compatibility. NOTE(review): it is
                currently not applied anywhere in this class.

        Raises:
            RuntimeError: If no ``ckpt`` is given and ``load_models()``
                returns no models.
        """
        # Set to True while construction runs and flipped to False at the end,
        # so it stays True if loading raises.
        self.loading = True
        self.device = device
        if ckpt:
            self.load_model_from_ckpt(ckpt)
        else:
            self.load_model_from_cfg(config)
        self.model.eval()
        self.width = width
        self.height = height
        self.transform = transforms.Compose([
            # torchvision's Resize takes (height, width).
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self.loading = False

    def load_model_from_ckpt(self, ckpt):
        """Load the object stored at ``ckpt`` and move it to ``self.device``.

        The loaded object is expected to be a full model (``.to`` and, later,
        ``.eval`` are called on it), not a bare state dict.
        """
        self.ckpt = ckpt
        # map_location remaps stored tensors onto the target device. Without
        # it, a checkpoint saved on a GPU fails to load on a CPU-only host or
        # on a different GPU.
        # SECURITY: torch.load unpickles arbitrary objects, so only load
        # trusted checkpoints.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which
        # rejects pickled nn.Module objects. Confirm the checkpoint format and
        # torch version.
        model = torch.load(self.ckpt, map_location=self.device)
        self.model = model.to(self.device)

    def load_model_from_cfg(self, config_path):
        """Use the first model returned by ``load_models()``.

        If ``config_path`` is truthy, it is first loaded into the shared
        ``cfg`` via ``cfg.load_yaml``. Otherwise ``cfg`` is left untouched.

        Raises:
            RuntimeError: If ``load_models()`` returns no models. Previously
                this left ``self.model`` unset and failed later with an
                ``AttributeError``.
        """
        if config_path:
            cfg.load_yaml(config_path)
        models = load_models()
        try:
            first_name = next(iter(models))
        except StopIteration:
            raise RuntimeError("load_models() returned no models") from None
        self.model = models[first_name].to(self.device)

    def inference(self, path, temperature=1.0):
        """Run the model on one image and return its answer.

        Args:
            path: Anything ``PIL.Image.open`` accepts (a file path or a binary
                file object).
            temperature: Forwarded to the model as ``sampling_temperature``.

        Returns:
            The first element of the model output, after the ``replace`` calls
            below. It is presumably the decoded string for the single input
            image; confirm against the model implementation.
        """
        # Close the underlying file promptly. convert() returns a fully loaded
        # copy, so it stays usable after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        image = self.transform(rgb).to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            answer = self.model(image=[image], sampling_temperature=temperature)
        # NOTE(review): as written, both replace() calls replace '' with '' and
        # are no-ops. They look meant to strip special tokens (e.g. start/end
        # markers) whose text is missing here. Confirm and restore the
        # intended token strings.
        answer = answer[0].replace('', "").replace("", "")
        return answer