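"""BLIP VQA demo script.

Loads a BLIP VQA checkpoint, answers a free-form question about an image and,
when run from the command line, plots per-token Grad-CAM overlays taken from
the text encoder's cross-attention.
"""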
import sys

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from scipy.ndimage import gaussian_filter
from skimage import transform as skimage_transform
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from models.blip_vqa import blip_vqa
class VQA:
    def __init__(self, model_path, image_size=480):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        # Cross-attention layer whose attention maps and gradients are kept for Grad-CAM.
        self.block_num = 9
        self.model.eval()
        self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.save_attention = True
        self.model = self.model.to(self.device)
    def getAttMap(self, img, attMap, blur=True, overlap=True):
        # Normalize the attention map to [0, 1].
        attMap -= attMap.min()
        if attMap.max() > 0:
            attMap /= attMap.max()
        # Upsample the patch-level map to the image resolution.
        attMap = skimage_transform.resize(attMap, img.shape[:2], order=3, mode='constant')
        if blur:
            attMap = gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
            attMap -= attMap.min()
            attMap /= attMap.max()
        # Colorize with the 'jet' colormap and drop the alpha channel.
        cmap = plt.get_cmap('jet')
        attMapV = cmap(attMap)
        attMapV = np.delete(attMapV, 3, 2)
        if overlap:
            # Blend the heatmap into the image, weighting by attention strength.
            attMap = (1 - attMap ** 0.7).reshape(attMap.shape + (1,)) * img \
                + (attMap ** 0.7).reshape(attMap.shape + (1,)) * attMapV
        return attMap
    def gradcam(self, text_input, image_path, image):
        # Zero out padding tokens when weighting the attention maps.
        mask = text_input.attention_mask.view(text_input.attention_mask.size(0), 1, -1, 1, 1)
        grads = self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.get_attn_gradients()
        cams = self.model.text_encoder.base_model.base_model.encoder.layer[self.block_num].crossattention.self.get_attention_map()
        # Drop the [CLS] image token and reshape the 12 heads onto the 30x30 patch grid (480 / 16).
        cams = cams[:, :, :, 1:].reshape(image.size(0), 12, -1, 30, 30) * mask
        grads = grads[:, :, :, 1:].clamp(0).reshape(image.size(0), 12, -1, 30, 30) * mask
        # Grad-CAM: attention maps weighted by their clamped gradients, averaged over heads.
        gradcam = cams * grads
        gradcam = gradcam[0].mean(0).cpu().detach()

        # One panel for the raw image plus one per question token.
        num_image = len(text_input.input_ids[0]) - 1
        fig, ax = plt.subplots(num_image, 1, figsize=(15, 15 * num_image))
        rgb_image = cv2.imread(image_path)[:, :, ::-1]
        rgb_image = np.float32(rgb_image) / 255
        ax[0].imshow(rgb_image)
        ax[0].set_yticks([])
        ax[0].set_xticks([])
        ax[0].set_xlabel("Image")
        # Skip [CLS] and [SEP]; show one Grad-CAM overlay per decoded token.
        for i, token_id in enumerate(text_input.input_ids[0][1:-1]):
            word = self.model.tokenizer.decode([token_id])
            gradcam_image = self.getAttMap(rgb_image, gradcam[i + 1])
            ax[i + 1].imshow(gradcam_image)
            ax[i + 1].set_yticks([])
            ax[i + 1].set_xticks([])
            ax[i + 1].set_xlabel(word)
        plt.show()
    def load_demo_image(self, image_size, img_path, device):
        raw_image = Image.open(img_path).convert('RGB')
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        image = transform(raw_image).unsqueeze(0).to(device)
        return raw_image, image
    def vqa(self, img_path, question):
        raw_image, image = self.load_demo_image(image_size=480, img_path=img_path, device=self.device)
        # Forward pass in 'gradcam' mode so the cross-attention maps are saved.
        answer, vl_output, que = self.model(image, question, mode='gradcam', inference='generate')
        # Backpropagate a scalar from the output logits to populate the attention gradients.
        loss = vl_output[:, 1].sum()
        self.model.zero_grad()
        loss.backward()
        with torch.no_grad():
            self.gradcam(que, img_path, image)
        return answer[0]
    def vqa_demo(self, image, question):
        # Variant used by the hosted demo: the image arrives as an array rather than a file path.
        image_size = 480
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        image = transform(image).unsqueeze(0).to(self.device)
        answer = self.model(image, question, mode='inference', inference='generate')
        return answer[0]

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print('Format: python3 vqa.py <path_to_img> <question>')
        print('Sample: python3 vqa.py sample.jpg "What is the color of the horse?"')
    else:
        model_path = 'checkpoints/model_base_vqa_capfilt_large.pth'
        vqa_object = VQA(model_path=model_path)
        img_path = sys.argv[1]
        question = sys.argv[2]
        answer = vqa_object.vqa(img_path, question)
        print('Question: {} | Answer: {}'.format(question, answer))
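
# Rough programmatic usage (a sketch; assumes the checkpoint path above and a local
# "sample.jpg" exist):
#
#   vqa_object = VQA(model_path='checkpoints/model_base_vqa_capfilt_large.pth')
#   image = np.array(Image.open('sample.jpg').convert('RGB'))
#   print(vqa_object.vqa_demo(image, 'What is the color of the horse?'))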