[CLIP interrogator] use local file, if available
This commit is contained in:
parent
98947d173e
commit
745f1e8f80
|
@ -14,6 +14,7 @@ import modules.shared as shared
|
||||||
from modules import devices, paths, lowvram
|
from modules import devices, paths, lowvram
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
|
blip_model_local = os.path.join('models', 'Interrogator', 'BLIP_model.pth')
|
||||||
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
|
||||||
clip_model_name = 'ViT-L/14'
|
clip_model_name = 'ViT-L/14'
|
||||||
|
|
||||||
|
@ -47,7 +48,13 @@ class InterrogateModels:
|
||||||
def load_blip_model(self):
|
def load_blip_model(self):
|
||||||
import models.blip
|
import models.blip
|
||||||
|
|
||||||
blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
if not os.path.isfile(blip_model_local):
|
||||||
|
print("Downloading BLIP...")
|
||||||
|
import requests as req
|
||||||
|
open(blip_model_local, 'wb').write(req.get(blip_model_url, allow_redirects=True).content)
|
||||||
|
print("BLIP downloaded to", blip_model_local + '.')
|
||||||
|
|
||||||
|
blip_model = models.blip.blip_decoder(pretrained=blip_model_local, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
|
||||||
blip_model.eval()
|
blip_model.eval()
|
||||||
|
|
||||||
return blip_model
|
return blip_model
|
||||||
|
|
Loading…
Reference in New Issue
Block a user