m7mdal7aj committed
Commit 7c3b785
Parent: b9ffc1e

Update my_model/tabs/run_inference.py

Files changed (1)
  1. my_model/tabs/run_inference.py +39 -25
my_model/tabs/run_inference.py CHANGED
@@ -16,14 +16,36 @@ from my_model.config import inference_config as config
 
 
 class InferenceRunner(StateManager):
+
+    """
+    InferenceRunner manages the user interface and interactions for a Streamlit-based
+    Knowledge-Based Visual Question Answering (KBVQA) application. It handles image uploads,
+    displays sample images, and facilitates the question-answering process using the KBVQA model.
+    it inherits the StateManager class.
+    """
+
     def __init__(self):
+        """
+        Initializes the InferenceRunner instance, setting up the necessary state.
+        """
 
         super().__init__()
         self.initialize_state()
-        self.sample_images = config.SAMPLE_IMAGES
 
 
     def answer_question(self, caption, detected_objects_str, question, model):
+        """
+        Generates an answer to a given question based on the image's caption and detected objects.
+
+        Args:
+            caption (str): The caption generated for the image.
+            detected_objects_str (str): String representation of objects detected in the image.
+            question (str): The user's question about the image.
+            model (KBVQA): The loaded KBVQA model used for generating the answer.
+
+        Returns:
+            str: The generated answer to the question.
+        """
         free_gpu_resources()
         answer = model.generate_answer(question, caption, detected_objects_str)
         free_gpu_resources()
@@ -31,10 +53,18 @@ class InferenceRunner(StateManager):
 
 
     def image_qa_app(self, kbvqa):
+        """
+        Main application interface for image-based question answering. It handles displaying
+        of sample images, uploading of new images, and facilitates the QA process.
+
+        Args:
+            kbvqa (KBVQA): The loaded KBVQA model used for image analysis and question answering.
+        """
+
         # Display sample images as clickable thumbnails
         self.col1.write("Choose from sample images:")
-        cols = self.col1.columns(len(self.sample_images))
-        for idx, sample_image_path in enumerate(self.sample_images):
+        cols = self.col1.columns(len(config.SAMPLE_IMAGES))
+        for idx, sample_image_path in enumerate(config.SAMPLE_IMAGES):
             with cols[idx]:
                 image = Image.open(sample_image_path)
                 image_for_display = self.resize_image(sample_image_path, 80, 80)
@@ -42,13 +72,8 @@ class InferenceRunner(StateManager):
                 if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
                     self.process_new_image(sample_image_path, image, kbvqa)
 
-
-
-
         # Image uploader
         uploaded_image = self.col1.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
-
-
         if uploaded_image is not None:
             self.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
 
@@ -67,7 +92,6 @@ class InferenceRunner(StateManager):
             caption, detected_objects_str, image_with_boxes = self.analyze_image(image_data['image'], kbvqa)
             self.update_image_data(image_key, caption, detected_objects_str, True)
             st.session_state['loading_in_progress'] = False
-
 
         # Initialize qa_history for each image
         qa_history = image_data.get('qa_history', [])
@@ -87,7 +111,6 @@ class InferenceRunner(StateManager):
         # Use the selected sample question or the custom question
         question = custom_question if selected_question == "Custom question..." else selected_question
 
-
         if not question:
             nested_col22.warning("Please select or enter a question.")
         else:
@@ -100,20 +123,19 @@ class InferenceRunner(StateManager):
                 st.session_state['loading_in_progress'] = False
                 self.add_to_qa_history(image_key, question, answer)
 
-
         # Display Q&A history for each image
         for num, (q, a) in enumerate(qa_history):
             nested_col22.text(f"Q{num+1}: {q}\nA{num+1}: {a}\n")
 
-    def display_message(self, message, warning=False, write=False, text=False):
-        pass
-
-
 
     def run_inference(self):
+        """
+        Sets up the widgets and manages the inference process. This method handles model loading,
+        reloading, and the overall flow of the inference process based on user interactions.
+
+        """
 
         self.set_up_widgets()
-
         load_fine_tuned_model = False
         fine_tuned_model_already_loaded = False
         reload_detection_model = False
@@ -123,27 +145,21 @@ class InferenceRunner(StateManager):
        if st.session_state['settings_changed']:
            self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
 
-
        st.session_state.button_label = "Reload Model" if self.is_model_loaded() and self.settings_changed else "Load Model"
 
        with self.col1:
-
            if st.session_state.method == "Fine-Tuned Model":
-
                with st.container():
                    nested_col11, nested_col12 = st.columns([0.5, 0.5])
                    if nested_col11.button(st.session_state.button_label, on_click=self.disable_widgets, disabled=self.is_widget_disabled):
-
                        if st.session_state.button_label == "Load Model":
                            if self.is_model_loaded():
                                free_gpu_resources()
                                fine_tuned_model_already_loaded = True
-
                            else:
                                load_fine_tuned_model = True
                        else:
                            reload_detection_model = True
-
                    if nested_col12.button("Force Reload", on_click=self.disable_widgets, disabled=self.is_widget_disabled):
                        force_reload_full_model = True
 
@@ -172,14 +188,12 @@ class InferenceRunner(StateManager):
                        st.session_state['time_taken_to_load_model'] = int(time.time()-t1)
                        st.session_state['loading_in_progress'] = False
                        st.session_state['model_loaded'] = True
-
+
            elif st.session_state.method == "In-Context Learning (n-shots)":
                self.col1.warning(f'Model using {st.session_state.method} is not deployed yet, will be ready later.')
                st.session_state['loading_in_progress'] = False
-
 
        if self.is_model_loaded():
-
            free_gpu_resources()
            st.session_state['loading_in_progress'] = False
            self.image_qa_app(self.get_model())
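
Note on the change: the commit drops the cached `self.sample_images` attribute and reads `config.SAMPLE_IMAGES` directly when building the thumbnail row. Neither `inference_config.py` nor the Streamlit entry point is part of this diff, so the sketch below only illustrates the assumed shape of that config value and how the updated class is presumably driven; the file paths and the `run_inference_tab` wrapper are hypothetical.

# Sketch only: inference_config.py and the app entry point are not shown in this
# commit; the paths and wiring below are assumptions for illustration.

# my_model/config/inference_config.py (assumed shape)
SAMPLE_IMAGES = [
    "my_model/assets/sample_1.jpg",  # hypothetical paths; the real list lives in the config module
    "my_model/assets/sample_2.jpg",
]

# Hypothetical Streamlit tab wiring around the updated class
from my_model.tabs.run_inference import InferenceRunner

def run_inference_tab():
    runner = InferenceRunner()   # __init__ calls initialize_state(); sample images are no longer cached on the instance
    runner.run_inference()       # builds the widgets, handles model (re)loading, then hands off to image_qa_app()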