Update my_model/KBVQA.py

my_model/KBVQA.py (+7 -4)
@@ -1,5 +1,6 @@
 import streamlit as st
 import torch
+import copy
 import os
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from typing import Optional
@@ -141,10 +142,12 @@ class KBVQA():
 
 
     def generate_answer(self, question, image):
-
-        st.
-
-
+        img = copy.deepcopy(image)
+        st.write('image being detected')
+        st.image(img)
+        caption = self.get_caption(img)
+        image_with_boxes, detected_objects_str = self.detect_objects(img)
+        st.write(detected_objects_str)
         prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
         num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
         if num_tokens > self.max_context_window: