Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +20 -1
my_model/KBVQA.py
CHANGED
@@ -21,6 +21,7 @@ class KBVQA():
|
|
21 |
self.kbvqa_tokenizer = None
|
22 |
self.captioner = None
|
23 |
self.detector = None
|
|
|
24 |
self.kbvqa_model = None
|
25 |
self.access_token = os.getenv("HUGGINGFACE_TOKEN")
|
26 |
# self.kbvqa_model_loaded = self.all_models_loaded()
|
@@ -87,6 +88,22 @@ class KBVQA():
|
|
87 |
def all_models_loaded(self):
|
88 |
return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
|
89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
|
92 |
def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
|
@@ -144,11 +161,13 @@ class KBVQA():
|
|
144 |
def prepare_kbvqa_model(detection_model):
|
145 |
free_gpu_resources()
|
146 |
kbvqa = KBVQA()
|
|
|
147 |
# Progress bar for model loading
|
148 |
with st.spinner('Loading model...'):
|
149 |
|
150 |
progress_bar = st.progress(0)
|
151 |
-
|
|
|
152 |
progress_bar.progress(33)
|
153 |
kbvqa.load_caption_model()
|
154 |
free_gpu_resources()
|
|
|
21 |
self.kbvqa_tokenizer = None
|
22 |
self.captioner = None
|
23 |
self.detector = None
|
24 |
+
sel.detection_model = None
|
25 |
self.kbvqa_model = None
|
26 |
self.access_token = os.getenv("HUGGINGFACE_TOKEN")
|
27 |
# self.kbvqa_model_loaded = self.all_models_loaded()
|
|
|
88 |
def all_models_loaded(self):
|
89 |
return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
|
90 |
|
91 |
+
def force_reload_model(self):
|
92 |
+
free_gpu_resources()
|
93 |
+
if self.kbvqa_model is not None:
|
94 |
+
del self.kbvqa_model
|
95 |
+
if self.captioner is not None:
|
96 |
+
del self.captioner
|
97 |
+
if self.detector is not None:
|
98 |
+
del self.detector
|
99 |
+
|
100 |
+
free_gpu_resources()
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
|
108 |
|
109 |
def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
|
|
|
161 |
def prepare_kbvqa_model(detection_model):
|
162 |
free_gpu_resources()
|
163 |
kbvqa = KBVQA()
|
164 |
+
kbvqa.detection_model = detection_model
|
165 |
# Progress bar for model loading
|
166 |
with st.spinner('Loading model...'):
|
167 |
|
168 |
progress_bar = st.progress(0)
|
169 |
+
|
170 |
+
kbvqa.load_detector(kbvqa.detection_model)
|
171 |
progress_bar.progress(33)
|
172 |
kbvqa.load_caption_model()
|
173 |
free_gpu_resources()
|