"""Inference wrapper that loads the configured engines and runs images through them."""

import torch
import torchaudio
import time
import logging

_logger = logging.getLogger(__name__)

from torch import Tensor
from einops import rearrange
from pathlib import Path

from .utils import to_device, set_seed, wrapper as ml

from PIL import Image, ImageDraw
import torchvision.transforms as transforms

from .config import cfg, Config
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_symmap, tokenize

if deepspeed_available:
    import deepspeed


class Classifier():
    """Load the engines described by the global ``cfg`` and run image inference.

    Construction resolves the runtime settings (``load_config``) and then
    loads the engines and the input transform (``load_model``).
    ``self.loading`` is ``True`` while that work is in progress.
    """

    def __init__(self, config=None, device=None, amp=None, dtype=None, attention=None):
        """Configure the global ``cfg`` and load the engines.

        All arguments are forwarded unchanged to ``load_config``.
        """
        self.loading = True

        # The arguments are spelled out rather than forwarded via **kwargs so
        # the accepted options stay visible in the signature.
        self.load_config(config=config, device=device, amp=amp, dtype=dtype, attention=attention)
        self.load_model()

        self.loading = False

    def load_config(self, config=None, device=None, amp=None, dtype=None, attention=None):
        """Resolve inference settings and write them into the global ``cfg``.

        Args:
            config: If truthy, passed to ``cfg.load_yaml`` before anything else.
            device: Target device; falls back to ``cfg.device`` when ``None``.
            amp: Whether to run under autocast; falls back to
                ``cfg.inference.amp`` when ``None``.
            dtype: Weight dtype; ``None`` or ``"auto"`` falls back to
                ``cfg.inference.weight_dtype``.
            attention: Accepted for interface compatibility; not used by this
                method.

        Raises:
            Exception: Whatever ``cfg.format`` raises is propagated on
                purpose; configuration errors must not be silenced.
        """
        if config:
            _logger.info(f"Loading YAML: {config}")
            cfg.load_yaml(config)

        # Errors from these calls deliberately propagate to the caller.
        cfg.format(training=False)
        # The HDF5 dataset is switched off for inferencing.
        cfg.dataset.use_hdf5 = False

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        # Publish the resolved values on the shared config object.
        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # Read back from cfg rather than using the local `dtype`.
        # NOTE(review): presumably `cfg.inference.dtype` is derived from
        # `weight_dtype` -- confirm in the config module.
        self.dtype = cfg.inference.dtype
        self.amp = amp

        # Extra keyword arguments forwarded to `load_engines`.
        self.model_kwargs = {}

    def load_model(self):
        """Load the engines, move them to the target device and build the transform.

        Sets ``self.engines``, ``self.symmap``, ``self.width``,
        ``self.height`` and ``self.transform``.
        """
        # Drop any cached engines so they are rebuilt with the current cfg.
        load_engines.cache_clear()

        self.engines = load_engines(training=False, **self.model_kwargs)

        # int8 engines are left untouched. Otherwise the weights are moved to
        # the device, kept in float32 when autocast (amp) is enabled.
        if self.dtype != torch.int8:
            target_dtype = torch.float32 if self.amp else self.dtype
            for _name, engine in self.engines.items():
                engine.to(self.device, dtype=target_dtype)

        self.engines.eval()
        self.symmap = get_symmap()

        # Every input image is resized to this fixed (height, width).
        self.width = 300
        self.height = 80

        self.transform = transforms.Compose([
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            # Three per-channel means / stds, so three-channel input is expected.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        _logger.info("Loaded model")

    @torch.inference_mode()
    def inference(self, image, temperature=1.0):
        """Run one image through the first loaded engine and return its output text.

        Args:
            image: Input accepted by ``self.transform``.
                NOTE(review): presumably a three-channel PIL image -- confirm
                against callers.
            temperature: Forwarded to the model as ``sampling_temperature``.

        Returns:
            The model output joined into a single string.

        Raises:
            RuntimeError: If no engine is loaded.
        """
        # Only the first engine is used.
        model = None
        for _name, engine in self.engines.items():
            model = engine.module
            break

        if model is None:
            raise RuntimeError("No engine is loaded; cannot run inference")

        image = self.transform(image).to(self.device).to(self.dtype)

        # NOTE(review): the autocast device type is hard-coded to "cuda"
        # regardless of `self.device` -- confirm this is intended for
        # non-CUDA devices.
        with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
            answer = model(image=[image], sampling_temperature=temperature)

        answer = "".join(answer)
        # NOTE(review): both search strings below are empty, which makes these
        # replace() calls no-ops. They look like they were meant to strip
        # start/stop marker tokens -- restore the intended markers if so
        # (slicing between the markers would be cleaner).
        answer = answer.replace('', "").replace("", "")

        return answer