Update my_model/KBVQA.py
my_model/KBVQA.py  (+27 -4)
@@ -222,7 +222,22 @@ class KBVQA:
         p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""

         return p
-
+
+    @staticmethod
+    def trim_objects(detected_objects_str):
+        """
+        Trim the last object from the detected objects string.
+
+        Args:
+        - detected_objects_str (str): String containing detected objects.
+
+        Returns:
+        - (str): The string with the last object removed.
+        """
+        objects = detected_objects_str.strip().split("\n")
+        if len(objects) >= 1:
+            return "\n".join(objects[:-1])
+        return ""

     def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
         """
@@ -236,13 +251,21 @@ class KBVQA:
         Returns:
             str: The generated answer to the question.
         """
+
+
         free_gpu_resources()
         prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
         num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
         self.current_prompt_length = num_tokens
-
-
-
+
+        while self.current_prompt_length > self.max_context_window:
+            detected_objects_str = self.trim_objects(detected_objects_str)
+            prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
+            self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
+
+            if detected_objects_str == "":
+                break  # Break if no objects are left
+

         model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
         free_gpu_resources()
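Taken together, the two hunks add a context-window guard: generate_answer measures the tokenized prompt, and while it exceeds max_context_window it drops the last detected object and re-measures. The sketch below is a minimal, self-contained illustration of that loop, not the project's code: count_tokens (a whitespace split standing in for kbvqa_tokenizer.tokenize), the simplified format_prompt, the wrapper fit_prompt, and the toy max_context_window value are all hypothetical stand-ins.

def trim_objects(detected_objects_str: str) -> str:
    """Drop the final line (one detected object) from the objects string."""
    objects = detected_objects_str.strip().split("\n")
    if len(objects) >= 1:
        return "\n".join(objects[:-1])
    return ""


def count_tokens(text: str) -> int:
    # Stand-in for len(self.kbvqa_tokenizer.tokenize(text)).
    return len(text.split())


def format_prompt(question: str, caption: str, objects: str) -> str:
    # Simplified stand-in for KBVQA.format_prompt.
    return f"caption: {caption}\nobjects:\n{objects}\nquestion: {question}"


def fit_prompt(question: str, caption: str, objects: str,
               max_context_window: int = 16) -> str:
    """Mirrors the while loop added to generate_answer."""
    prompt = format_prompt(question, caption, objects)
    current_prompt_length = count_tokens(prompt)
    while current_prompt_length > max_context_window:
        objects = trim_objects(objects)  # drop one object per pass
        prompt = format_prompt(question, caption, objects)
        current_prompt_length = count_tokens(prompt)
        if objects == "":
            break  # nothing left to trim; prompt may still exceed the window
    return prompt


if __name__ == "__main__":
    objs = "cat 0.98\ndog 0.91\nchair 0.55\nlamp 0.40"
    print(fit_prompt("What animal is on the chair?", "a living room", objs))

With the toy 16-token window, the example trims "lamp" and then "chair" before the prompt fits, which is exactly the behavior the diff introduces against the real tokenizer.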
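Two consequences of the loop's ordering are worth noting. The guard trims before it checks for emptiness, so if the caption and question alone exceed the window the loop exits with an over-long prompt rather than spinning forever; and because trimming always removes the last listed object, a detector that lists objects in descending confidence will sacrifice its least confident entries first.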