import torch
import torchvision.transforms as transforms

from .config import cfg
from .export import load_models
from .data import get_symmap, _get_symbols


class CAPTCHA():
    """Wrap a CAPTCHA-solving model for single-image inference.

    The model is either unpickled from a checkpoint file (``ckpt``) or taken
    from the first entry returned by ``load_models()``. Each input image is
    resized to ``(height, width)``, converted to a tensor, and normalized with
    the fixed per-channel mean/std below before being fed to the model.
    """

    def __init__(
        self,
        width=300,
        height=80,
        config=None,
        ckpt=None,
        device="cuda",
        dtype="float32",
    ):
        """Load the model and build the preprocessing pipeline.

        Args:
            width: Target image width passed to ``transforms.Resize``.
            height: Target image height passed to ``transforms.Resize``.
            config: Optional YAML path; when truthy it is loaded into the
                shared ``cfg`` object before the model is built.
            ckpt: Optional checkpoint path. When given, the model is loaded
                with ``torch.load``; otherwise ``load_models()`` is used.
            device: Device the model and input tensors are moved to.
            dtype: Currently unused (kept for interface compatibility).
                NOTE(review): casting only the model would mismatch the
                float32 inputs produced by ``ToTensor``; wire this up for
                both together if it is ever needed.
        """
        # True while the model and transform are being built.
        self.loading = True
        self.device = device

        if config:
            # Applies the YAML to the module-level `cfg` object shared with
            # the rest of the package.
            cfg.load_yaml(config)

        if ckpt:
            self.load_model_from_ckpt(ckpt)
        else:
            self.load_model_from_cfg(config)

        # Switch layers such as dropout/batch-norm to inference behaviour.
        self.model.eval()

        self.width = width
        self.height = height
        # The mean/std values are the widely used ImageNet statistics, which
        # require a 3-channel input.
        # NOTE(review): presumably callers pass RGB PIL images — confirm.
        self.transform = transforms.Compose([
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.loading = False

    def load_model_from_ckpt(self, ckpt):
        """Load a fully pickled model from ``ckpt`` onto ``self.device``.

        SECURITY: ``torch.load`` unpickles arbitrary Python objects; only load
        checkpoints from trusted sources.
        NOTE(review): on PyTorch >= 2.6, ``torch.load`` defaults to
        ``weights_only=True``, which may reject a pickled ``nn.Module`` —
        confirm what this checkpoint contains.
        """
        self.ckpt = ckpt
        # map_location remaps tensors saved on another device (e.g. a GPU
        # checkpoint loaded on a CPU-only host) instead of failing.
        self.model = torch.load(self.ckpt, map_location=self.device).to(self.device)

    def load_model_from_cfg(self, config_path):
        """Use the first model yielded by ``load_models()``.

        ``config_path`` is not read here. Any config was already applied to
        the shared ``cfg`` in ``__init__``; it is kept for interface
        compatibility.

        Raises:
            RuntimeError: if ``load_models()`` yields no models.
        """
        models = load_models()
        # Only the first entry yielded by iterating `models` is used.
        for name in models:
            self.model = models[name].to(self.device)
            return
        raise RuntimeError("load_models() returned no models to load")

    def inference(self, image, temperature=1.0):
        """Run the model on a single image and return its answer.

        Args:
            image: Input accepted by ``self.transform`` (presumably a PIL
                image — confirm against callers).
            temperature: Forwarded to the model as ``sampling_temperature``.

        Returns:
            The first element of the model output, after the string
            replacements below (presumably the decoded CAPTCHA text).
        """
        tensor = self.transform(image).to(self.device)
        # No gradients are needed at inference time. Disabling them avoids
        # building an autograd graph and holding activations in memory.
        with torch.no_grad():
            answer = self.model(image=[tensor], sampling_temperature=temperature)
        # NOTE(review): replacing '' with "" is a no-op. These look like they
        # were meant to strip special tokens (e.g. start/end markers) whose
        # literals were lost — restore the intended tokens.
        answer = answer[0].replace('', "").replace("", "")
        return answer