102 lines
2.9 KiB
Python
Executable File
102 lines
2.9 KiB
Python
Executable File
import torch
|
|
import torchaudio
|
|
import time
|
|
import logging
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
from torch import Tensor
|
|
from einops import rearrange
|
|
from pathlib import Path
|
|
|
|
from .utils import to_device, set_seed, wrapper as ml
|
|
from PIL import Image, ImageDraw
|
|
import torchvision.transforms as transforms
|
|
|
|
from .config import cfg, Config
|
|
from .models import get_models
|
|
from .engines import load_engines, deepspeed_available
|
|
from .data import get_symmap, tokenize
|
|
|
|
if deepspeed_available:
|
|
import deepspeed
|
|
|
|
class Classifier():
|
|
def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
|
|
self.loading = True
|
|
|
|
# yes I can just grab **kwargs and forward them here
|
|
self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
|
|
self.load_model()
|
|
|
|
self.loading = False
|
|
|
|
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
|
|
if config:
|
|
_logger.info(f"Loading YAML: {config}")
|
|
cfg.load_yaml( config )
|
|
|
|
try:
|
|
cfg.format( training=False )
|
|
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
|
|
except Exception as e:
|
|
raise e # throw an error because I'm tired of silent errors messing things up for me
|
|
|
|
if amp is None:
|
|
amp = cfg.inference.amp
|
|
if dtype is None or dtype == "auto":
|
|
dtype = cfg.inference.weight_dtype
|
|
if device is None:
|
|
device = cfg.device
|
|
|
|
cfg.device = device
|
|
cfg.mode = "inferencing"
|
|
cfg.trainer.backend = cfg.inference.backend
|
|
cfg.trainer.weight_dtype = dtype
|
|
cfg.inference.weight_dtype = dtype
|
|
|
|
self.device = device
|
|
self.dtype = cfg.inference.dtype
|
|
self.amp = amp
|
|
|
|
self.model_kwargs = {}
|
|
|
|
def load_model( self ):
|
|
load_engines.cache_clear()
|
|
|
|
self.engines = load_engines(training=False, **self.model_kwargs)
|
|
for name, engine in self.engines.items():
|
|
if self.dtype != torch.int8:
|
|
engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
|
|
|
self.engines.eval()
|
|
self.symmap = get_symmap()
|
|
|
|
self.width = 300
|
|
self.height = 80
|
|
self.transform = transforms.Compose([
|
|
transforms.Resize((self.height, self.width)),
|
|
transforms.ToTensor(),
|
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
|
])
|
|
|
|
_logger.info("Loaded model")
|
|
|
|
@torch.inference_mode()
|
|
def inference( self, image, temperature=1.0 ):
|
|
model = None
|
|
|
|
for name, engine in self.engines.items():
|
|
model = engine.module
|
|
break
|
|
|
|
image = self.transform(image).to(self.device).to(self.dtype)
|
|
|
|
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
|
|
answer = model( image=[image], sampling_temperature=temperature )
|
|
|
|
answer = [ "".join(answer) ]
|
|
|
|
answer = answer[0].replace('<s>', "").replace("</s>", "") # it would be better to just slice between these, but I can't be assed
|
|
|
|
return answer |