import os
import sys
import traceback
from collections import namedtuple
import re

import torch

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths

# Side length (pixels) of the square image fed to BLIP; also passed to
# blip_decoder as its image_size.
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'

# One tag category loaded from a text file: its filename, how many items to
# pick from it, and the candidate strings (one per non-blank line).
Category = namedtuple("Category", ["name", "topn", "items"])

# Matches a ".topN." fragment in a category filename, e.g. "artists.top3.txt".
re_topn = re.compile(r"\.top(\d+)\.")


class InterrogateModels:
    """Build a text prompt for an image from a BLIP caption plus CLIP-ranked tags.

    Models are loaded lazily by load() and moved off the device by unload()
    unless ``shared.opts.interrogate_keep_models_in_memory`` is set.
    """

    blip_model = None  # BLIP caption decoder, created on first load()
    clip_model = None  # CLIP model used to rank candidate strings
    clip_preprocess = None  # image transform returned alongside the CLIP model
    categories = None  # list of Category, filled in __init__

    def __init__(self, content_dir):
        """Load tag categories from the text files in *content_dir*.

        Each regular file becomes one Category whose items are the file's
        stripped, non-blank lines. A ``.topN.`` fragment in the filename
        (e.g. ``artists.top3.txt``) sets how many items are picked from that
        category; otherwise one item is picked. A missing directory yields
        no categories.
        """
        self.categories = []

        if not os.path.exists(content_dir):
            return

        # os.listdir order is arbitrary; sort so tags are appended to the
        # prompt in the same order on every platform.
        for filename in sorted(os.listdir(content_dir)):
            path = os.path.join(content_dir, filename)

            # Subdirectories and other non-regular entries cannot be read as
            # text and would make open() raise.
            if not os.path.isfile(path):
                continue

            m = re_topn.search(filename)
            topn = 1 if m is None else int(m.group(1))

            with open(path, "r", encoding="utf8") as file:
                # Blank lines would otherwise become empty candidates that can
                # win the ranking and append a bare ", " to the prompt.
                items = [line.strip() for line in file if line.strip()]

            self.categories.append(Category(name=filename, topn=topn, items=items))

    def load_blip_model(self):
        """Create the BLIP caption decoder from blip_model_url and put it in eval mode."""
        import models.blip

        blip_model = models.blip.blip_decoder(
            pretrained=blip_model_url,
            image_size=blip_image_eval_size,
            vit='base',
            med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"),
        )
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        """Load the CLIP model named by clip_model_name.

        Returns:
            ``(model, preprocess)``: the model in eval mode on shared.device,
            and the image preprocessing transform returned by ``clip.load``.
        """
        import clip

        model, preprocess = clip.load(clip_model_name)
        model.eval()
        model = model.to(shared.device)

        return model, preprocess

    def load(self):
        """Create any model not yet loaded, then move both models to shared.device."""
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()

        self.blip_model = self.blip_model.to(shared.device)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()

        self.clip_model = self.clip_model.to(shared.device)

    def unload(self):
        """Move loaded models to the CPU unless the user asked to keep them resident."""
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def rank(self, image_features, text_array, top_count=1):
        """Rank the strings in *text_array* by CLIP similarity to *image_features*.

        Args:
            image_features: CLIP image embeddings, one per row. interrogate()
                passes L2-normalized features; presumably shape (n, d) —
                only ``shape[0]`` and row indexing are relied on here.
            text_array: candidate strings to rank.
            top_count: number of best candidates to return; clamped to
                ``len(text_array)``.

        Returns:
            A list of ``(text, score)`` pairs, best first. Score is the softmax
            probability over *text_array*, averaged over image rows, times 100.
            Empty when *text_array* is empty.
        """
        import clip

        # Tokenizing/encoding an empty batch fails; there is nothing to rank.
        if not text_array:
            return []

        top_count = min(top_count, len(text_array))
        # Use the configured device rather than hard-coded CUDA so ranking
        # works wherever load() placed the CLIP model.
        text_tokens = clip.tokenize(list(text_array)).to(shared.device)
        with torch.no_grad():
            text_features = self.clip_model.encode_text(text_tokens).float()
        text_features /= text_features.norm(dim=-1, keepdim=True)

        similarity = torch.zeros((1, len(text_array)), device=shared.device)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[int(top_labels[0][i])], top_probs[0][i].numpy() * 100) for i in range(top_count)]

    def generate_caption(self, pil_image):
        """Return the first caption BLIP generates for *pil_image*.

        The image is resized (bicubic) to a blip_image_eval_size square,
        converted to a tensor, normalized with fixed per-channel mean/std and
        given a batch dimension. Beam search settings come from
        ``shared.opts.interrogate_clip_*``.
        """
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])(pil_image).unsqueeze(0).to(shared.device)

        with torch.no_grad():
            caption = self.blip_model.generate(
                gpu_image,
                sample=False,
                num_beams=shared.opts.interrogate_clip_num_beams,
                min_length=shared.opts.interrogate_clip_min_length,
                max_length=shared.opts.interrogate_clip_max_length,
            )

        return caption[0]

    def interrogate(self, pil_image):
        """Describe *pil_image* as a comma-separated prompt string.

        Starts from the BLIP caption. It then appends the best-matching
        builtin artist (when ``shared.opts.interrogate_use_builtin_artists``
        is set) and the top ``topn`` items of each loaded category, as
        ranked by CLIP.

        Errors are printed to stderr instead of raised. In that case the
        partially built string is returned, or None if captioning itself
        failed. unload() is always called before returning.
        """
        res = None

        try:
            self.load()

            caption = self.generate_caption(pil_image)
            res = caption

            images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device)

            with torch.no_grad():
                image_features = self.clip_model.encode_image(images).float()

            image_features /= image_features.norm(dim=-1, keepdim=True)

            if shared.opts.interrogate_use_builtin_artists:
                artists = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])
                # An empty artist database yields no match; skip it instead of indexing [0].
                if artists:
                    res += ", " + artists[0][0]

            for category in self.categories:
                matches = self.rank(image_features, category.items, top_count=category.topn)
                for match, _score in matches:
                    res += ", " + match

        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

        finally:
            # Also runs on KeyboardInterrupt and similar, so models are not
            # left occupying the device after an aborted interrogation.
            self.unload()

        return res