Update my_model/tabs/run_inference.py
Browse files- my_model/tabs/run_inference.py +39 -25
my_model/tabs/run_inference.py
CHANGED
@@ -16,14 +16,36 @@ from my_model.config import inference_config as config
|
|
16 |
|
17 |
|
18 |
class InferenceRunner(StateManager):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
def __init__(self):
|
|
|
|
|
|
|
20 |
|
21 |
super().__init__()
|
22 |
self.initialize_state()
|
23 |
-
self.sample_images = config.SAMPLE_IMAGES
|
24 |
|
25 |
|
26 |
def answer_question(self, caption, detected_objects_str, question, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
free_gpu_resources()
|
28 |
answer = model.generate_answer(question, caption, detected_objects_str)
|
29 |
free_gpu_resources()
|
@@ -31,10 +53,18 @@ class InferenceRunner(StateManager):
|
|
31 |
|
32 |
|
33 |
def image_qa_app(self, kbvqa):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# Display sample images as clickable thumbnails
|
35 |
self.col1.write("Choose from sample images:")
|
36 |
-
cols = self.col1.columns(len(
|
37 |
-
for idx, sample_image_path in enumerate(
|
38 |
with cols[idx]:
|
39 |
image = Image.open(sample_image_path)
|
40 |
image_for_display = self.resize_image(sample_image_path, 80, 80)
|
@@ -42,13 +72,8 @@ class InferenceRunner(StateManager):
|
|
42 |
if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
|
43 |
self.process_new_image(sample_image_path, image, kbvqa)
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
# Image uploader
|
49 |
uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
|
50 |
-
|
51 |
-
|
52 |
if uploaded_image is not None:
|
53 |
self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
|
54 |
|
@@ -67,7 +92,6 @@ class InferenceRunner(StateManager):
|
|
67 |
caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
|
68 |
self.update_image_data(image_key, caption, detected_objects_str, True)
|
69 |
st.session_state['loading_in_progress'] = False
|
70 |
-
|
71 |
|
72 |
# Initialize qa_history for each image
|
73 |
qa_history = image_data.get('qa_history', [])
|
@@ -87,7 +111,6 @@ class InferenceRunner(StateManager):
|
|
87 |
# Use the selected sample question or the custom question
|
88 |
question = custom_question if selected_question == "Custom question..." else selected_question
|
89 |
|
90 |
-
|
91 |
if not question:
|
92 |
nested_col22.warning("Please select or enter a question.")
|
93 |
else:
|
@@ -100,20 +123,19 @@ class InferenceRunner(StateManager):
|
|
100 |
st.session_state['loading_in_progress'] = False
|
101 |
self.add_to_qa_history(image_key, question, answer)
|
102 |
|
103 |
-
|
104 |
# Display Q&A history for each image
|
105 |
for num, (q, a) in enumerate(qa_history):
|
106 |
nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\n")
|
107 |
|
108 |
-
def display_message(self, message, warning=False, write=False, text=False):
    """Intentionally do nothing: swallow messages routed through this hook.

    Args:
        message: Message text to display (ignored).
        warning (bool): Flag requesting warning styling (ignored).
        write (bool): Flag requesting rendering via ``st.write`` (ignored).
        text (bool): Flag requesting rendering via ``st.text`` (ignored).

    NOTE(review): presumably overrides a ``StateManager`` display hook to
    suppress its output in this tab — confirm against the base class.
    """
    pass
|
110 |
-
|
111 |
-
|
112 |
|
113 |
def run_inference(self):
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
self.set_up_widgets()
|
116 |
-
|
117 |
load_fine_tuned_model = False
|
118 |
fine_tuned_model_already_loaded = False
|
119 |
reload_detection_model = False
|
@@ -123,27 +145,21 @@ class InferenceRunner(StateManager):
|
|
123 |
if st.session_state['settings_changed']:
|
124 |
self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
|
125 |
|
126 |
-
|
127 |
st.session_state.button_label = "Reload Model" if self.is_model_loaded() and self.settings_changed else "Load Model"
|
128 |
|
129 |
with self.col1:
|
130 |
-
|
131 |
if st.session_state.method == "Fine-Tuned Model":
|
132 |
-
|
133 |
with st.container():
|
134 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
135 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
136 |
-
|
137 |
if st.session_state.button_label == "Load Model":
|
138 |
if self.is_model_loaded():
|
139 |
free_gpu_resources()
|
140 |
fine_tuned_model_already_loaded = True
|
141 |
-
|
142 |
else:
|
143 |
load_fine_tuned_model = True
|
144 |
else:
|
145 |
reload_detection_model = True
|
146 |
-
|
147 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
148 |
force_reload_full_model = True
|
149 |
|
@@ -172,14 +188,12 @@ class InferenceRunner(StateManager):
|
|
172 |
st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
|
173 |
st.session_state['loading_in_progress'] = False
|
174 |
st.session_state['model_loaded'] = True
|
175 |
-
|
176 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
177 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
178 |
st.session_state['loading_in_progress'] = False
|
179 |
-
|
180 |
|
181 |
if self.is_model_loaded():
|
182 |
-
|
183 |
free_gpu_resources()
|
184 |
st.session_state['loading_in_progress'] = False
|
185 |
self.image_qa_app(self.get_model())
|
|
|
16 |
|
17 |
|
18 |
class InferenceRunner(StateManager):
|
19 |
+
|
20 |
+
"""
|
21 |
+
InferenceRunner manages the user interface and interactions for a Streamlit-based
|
22 |
+
Knowledge-Based Visual Question Answering (KBVQA) application. It handles image uploads,
|
23 |
+
displays sample images, and facilitates the question-answering process using the KBVQA model.
|
24 |
+
It inherits from the StateManager class.
|
25 |
+
"""
|
26 |
+
|
27 |
def __init__(self):
    """Initialize the InferenceRunner and seed its session state.

    Runs the parent ``StateManager`` initializer first, then populates
    the Streamlit session state via ``initialize_state`` so that later
    widget callbacks find every expected key already present.
    """
    super().__init__()
    self.initialize_state()
|
|
|
34 |
|
35 |
|
36 |
def answer_question(self, caption, detected_objects_str, question, model):
|
37 |
+
"""
|
38 |
+
Generates an answer to a given question based on the image's caption and detected objects.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
caption (str): The caption generated for the image.
|
42 |
+
detected_objects_str (str): String representation of objects detected in the image.
|
43 |
+
question (str): The user's question about the image.
|
44 |
+
model (KBVQA): The loaded KBVQA model used for generating the answer.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
str: The generated answer to the question.
|
48 |
+
"""
|
49 |
free_gpu_resources()
|
50 |
answer = model.generate_answer(question, caption, detected_objects_str)
|
51 |
free_gpu_resources()
|
|
|
53 |
|
54 |
|
55 |
def image_qa_app(self, kbvqa):
|
56 |
+
"""
|
57 |
+
Main application interface for image-based question answering. It handles displaying
|
58 |
+
of sample images, uploading of new images, and facilitates the QA process.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
kbvqa (KBVQA): The loaded KBVQA model used for image analysis and question answering.
|
62 |
+
"""
|
63 |
+
|
64 |
# Display sample images as clickable thumbnails
|
65 |
self.col1.write("Choose from sample images:")
|
66 |
+
cols = self.col1.columns(len(config.SAMPLE_IMAGES))
|
67 |
+
for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
|
68 |
with cols[idx]:
|
69 |
image = Image.open(sample_image_path)
|
70 |
image_for_display = self.resize_image(sample_image_path, 80, 80)
|
|
|
72 |
if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
|
73 |
self.process_new_image(sample_image_path, image, kbvqa)
|
74 |
|
|
|
|
|
|
|
75 |
# Image uploader
|
76 |
uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
|
|
|
|
|
77 |
if uploaded_image is not None:
|
78 |
self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
|
79 |
|
|
|
92 |
caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
|
93 |
self.update_image_data(image_key, caption, detected_objects_str, True)
|
94 |
st.session_state['loading_in_progress'] = False
|
|
|
95 |
|
96 |
# Initialize qa_history for each image
|
97 |
qa_history = image_data.get('qa_history', [])
|
|
|
111 |
# Use the selected sample question or the custom question
|
112 |
question = custom_question if selected_question == "Custom question..." else selected_question
|
113 |
|
|
|
114 |
if not question:
|
115 |
nested_col22.warning("Please select or enter a question.")
|
116 |
else:
|
|
|
123 |
st.session_state['loading_in_progress'] = False
|
124 |
self.add_to_qa_history(image_key, question, answer)
|
125 |
|
|
|
126 |
# Display Q&A history for each image
|
127 |
for num, (q, a) in enumerate(qa_history):
|
128 |
nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\n")
|
129 |
|
|
|
|
|
|
|
|
|
130 |
|
131 |
def run_inference(self):
|
132 |
+
"""
|
133 |
+
Sets up the widgets and manages the inference process. This method handles model loading,
|
134 |
+
reloading, and the overall flow of the inference process based on user interactions.
|
135 |
+
|
136 |
+
"""
|
137 |
|
138 |
self.set_up_widgets()
|
|
|
139 |
load_fine_tuned_model = False
|
140 |
fine_tuned_model_already_loaded = False
|
141 |
reload_detection_model = False
|
|
|
145 |
if st.session_state['settings_changed']:
|
146 |
self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
|
147 |
|
|
|
148 |
st.session_state.button_label = "Reload Model" if self.is_model_loaded() and self.settings_changed else "Load Model"
|
149 |
|
150 |
with self.col1:
|
|
|
151 |
if st.session_state.method == "Fine-Tuned Model":
|
|
|
152 |
with st.container():
|
153 |
nested_col11, nested_col12 = st.columns([0.5, 0.5])
|
154 |
if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
|
|
155 |
if st.session_state.button_label == "Load Model":
|
156 |
if self.is_model_loaded():
|
157 |
free_gpu_resources()
|
158 |
fine_tuned_model_already_loaded = True
|
|
|
159 |
else:
|
160 |
load_fine_tuned_model = True
|
161 |
else:
|
162 |
reload_detection_model = True
|
|
|
163 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
|
164 |
force_reload_full_model = True
|
165 |
|
|
|
188 |
st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
|
189 |
st.session_state['loading_in_progress'] = False
|
190 |
st.session_state['model_loaded'] = True
|
191 |
+
|
192 |
elif st.session_state.method == "In-Context Learning (n-shots)":
|
193 |
self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
|
194 |
st.session_state['loading_in_progress'] = False
|
|
|
195 |
|
196 |
if self.is_model_loaded():
|
|
|
197 |
free_gpu_resources()
|
198 |
st.session_state['loading_in_progress'] = False
|
199 |
self.image_qa_app(self.get_model())
|