Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +11 -3
my_model/KBVQA.py
CHANGED
@@ -176,12 +176,16 @@ class KBVQA:
|
|
176 |
free_gpu_resources()
|
177 |
if self.kbvqa_model is not None:
|
178 |
del self.kbvqa_model
|
|
|
179 |
if self.captioner is not None:
|
180 |
del self.captioner
|
|
|
181 |
if self.detector is not None:
|
182 |
del self.detector
|
183 |
-
|
|
|
184 |
free_gpu_resources()
|
|
|
185 |
|
186 |
|
187 |
def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
|
@@ -253,7 +257,7 @@ class KBVQA:
|
|
253 |
|
254 |
return output_text.capitalize()
|
255 |
|
256 |
-
def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
|
257 |
"""
|
258 |
Prepares the KBVQA model for use, including loading necessary sub-models.
|
259 |
|
@@ -269,7 +273,11 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
|
|
269 |
kbvqa.detection_model = st.session_state.detection_model
|
270 |
# Progress bar for model loading
|
271 |
with kbvqa.col1:
|
272 |
-
|
|
|
|
|
|
|
|
|
273 |
if not only_reload_detection_model:
|
274 |
progress_bar = st.progress(0)
|
275 |
kbvqa.load_detector(kbvqa.detection_model)
|
|
|
176 |
free_gpu_resources()
|
177 |
if self.kbvqa_model is not None:
|
178 |
del self.kbvqa_model
|
179 |
+
free_gpu_resources()
|
180 |
if self.captioner is not None:
|
181 |
del self.captioner
|
182 |
+
free_gpu_resources()
|
183 |
if self.detector is not None:
|
184 |
del self.detector
|
185 |
+
free_gpu_resources()
|
186 |
+
|
187 |
free_gpu_resources()
|
188 |
+
prepare_kbvqa_model(only_reload_detection_model=False, force_reload=True)
|
189 |
|
190 |
|
191 |
def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
|
|
|
257 |
|
258 |
return output_text.capitalize()
|
259 |
|
260 |
+
def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = True) -> KBVQA:
|
261 |
"""
|
262 |
Prepares the KBVQA model for use, including loading necessary sub-models.
|
263 |
|
|
|
273 |
kbvqa.detection_model = st.session_state.detection_model
|
274 |
# Progress bar for model loading
|
275 |
with kbvqa.col1:
|
276 |
+
if force_reload:
|
277 |
+
loading_message = 'Force Reloading model.. this should take no more than a few minutes!'
|
278 |
+
else: loading_message = 'Looading model.. this should take no more than a few minutes!'
|
279 |
+
|
280 |
+
with st.spinner(loading_message):
|
281 |
if not only_reload_detection_model:
|
282 |
progress_bar = st.progress(0)
|
283 |
kbvqa.load_detector(kbvqa.detection_model)
|