resnet-classifier/image_classifier/inference.py
2024-09-04 15:57:32 -05:00

102 lines
2.9 KiB
Python
Executable File

import contextlib
import logging
import time
from pathlib import Path

import torch
import torchaudio
import torchvision.transforms as transforms
from torch import Tensor
from einops import rearrange
from PIL import Image, ImageDraw

from .utils import to_device, set_seed, wrapper as ml
from .config import cfg, Config
from .models import get_models
from .engines import load_engines, deepspeed_available
from .data import get_symmap, tokenize

_logger = logging.getLogger(__name__)

if deepspeed_available:
    import deepspeed
class Classifier():
    """Inference wrapper that loads the config and engines, then classifies images.

    Construction applies device/dtype/amp overrides to the global ``cfg``,
    loads the engines via ``load_engines`` and prepares the image transform.
    """

    def __init__( self, config=None, device=None, amp=None, dtype=None, attention=None ):
        """Load the configuration and the model engines.

        Args:
            config: optional YAML config path, passed to ``cfg.load_yaml``.
            device: device to run on; defaults to ``cfg.device``.
            amp: enable autocast during inference; defaults to ``cfg.inference.amp``.
            dtype: weight dtype; ``None`` or ``"auto"`` falls back to
                ``cfg.inference.weight_dtype``.
            attention: accepted for interface compatibility; currently unused.
        """
        # True only while the config and model are being loaded.
        self.loading = True
        # yes I can just grab **kwargs and forward them here
        self.load_config( config=config, device=device, amp=amp, dtype=dtype, attention=attention )
        self.load_model()
        self.loading = False

    def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
        """Apply the YAML config and inference overrides to the global ``cfg``.

        Sets ``self.device``, ``self.dtype``, ``self.amp`` and ``self.model_kwargs``.
        Errors from ``cfg.format`` propagate to the caller on purpose (no silent failures).
        ``attention`` is accepted but not used or forwarded.
        """
        if config:
            _logger.info(f"Loading YAML: {config}")
            cfg.load_yaml( config )

        cfg.format( training=False )
        # could use cfg.load_hdf5(), but it is never needed for inferencing
        cfg.dataset.use_hdf5 = False

        if amp is None:
            amp = cfg.inference.amp
        if dtype is None or dtype == "auto":
            dtype = cfg.inference.weight_dtype
        if device is None:
            device = cfg.device

        cfg.device = device
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype

        self.device = device
        # presumably the torch dtype derived from cfg.inference.weight_dtype -- TODO confirm in config
        self.dtype = cfg.inference.dtype
        self.amp = amp
        # extra keyword arguments forwarded to load_engines()
        self.model_kwargs = {}

    def load_model( self ):
        """(Re)load the engines, the symbol map and the image preprocessing transform."""
        # clear load_engines' cache so the engines are rebuilt instead of reused
        load_engines.cache_clear()
        self.engines = load_engines(training=False, **self.model_kwargs)
        for _name, engine in self.engines.items():
            # int8 engines are left as loaded; with amp the weights stay float32 and
            # autocast in inference() handles the reduced precision
            if self.dtype != torch.int8:
                engine.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
        self.engines.eval()

        self.symmap = get_symmap()

        # fixed input resolution fed to the model
        self.width = 300
        self.height = 80
        self.transform = transforms.Compose([
            transforms.Resize((self.height, self.width)),
            transforms.ToTensor(),
            # 3-channel (RGB) normalization statistics
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        _logger.info("Loaded model")

    def _input_dtype( self, model ):
        """Return the dtype the transformed image is cast to before the forward pass."""
        if self.dtype != torch.int8:
            return self.dtype
        # load_model() does not cast int8 engines, and casting a normalized float image
        # to int8 would truncate it to a handful of integers. Match the first
        # floating-point parameter instead (assumes the input layer is not quantized
        # -- TODO confirm against the int8 backend).
        for param in model.parameters():
            if param.is_floating_point():
                return param.dtype
        return torch.float32

    @torch.inference_mode()
    def inference( self, image, temperature=1.0 ):
        """Classify one image and return the decoded output string.

        Args:
            image: a PIL image (or anything ``self.transform`` accepts); non-RGB
                PIL images are converted to RGB first.
            temperature: forwarded to the model as ``sampling_temperature``.

        Returns:
            The model's output tokens joined into a string with ``<s>``/``</s>`` removed.

        Raises:
            RuntimeError: if no engines are loaded.
        """
        # only the first loaded engine is used
        model = next((engine.module for _name, engine in self.engines.items()), None)
        if model is None:
            raise RuntimeError("Classifier.inference: no engines are loaded")

        # Normalize() uses 3-channel statistics; grayscale/RGBA/palette images
        # would fail to broadcast against them.
        mode = getattr(image, "mode", None)
        if isinstance(mode, str) and mode != "RGB":
            image = image.convert("RGB")

        image = self.transform(image).to(self.device).to(self._input_dtype(model))

        if self.amp:
            # autocast must target the device the tensors live on, not always "cuda"
            autocast = torch.autocast(torch.device(self.device).type, dtype=self.dtype)
        else:
            autocast = contextlib.nullcontext()
        with autocast:
            tokens = model( image=[image], sampling_temperature=temperature )

        # NOTE: slicing between <s> and </s> would be stricter than stripping the markers
        return "".join(tokens).replace("<s>", "").replace("</s>", "")