m7mdal7aj committed on
Commit
1a4d79e
1 Parent(s): 212cd39

Update my_model/tabs/run_inference.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/run_inference.py +26 -72
my_model/tabs/run_inference.py CHANGED
@@ -11,8 +11,9 @@ from my_model.object_detection import detect_and_draw_objects
11
  from my_model.captioner.image_captioning import get_caption
12
  from my_model.gen_utilities import free_gpu_resources
13
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
 
14
 
15
-
16
 
17
  def answer_question(caption, detected_objects_str, question, model):
18
  free_gpu_resources()
@@ -39,10 +40,7 @@ def analyze_image(image, model):
39
  return caption, detected_objects_str, image_with_boxes
40
 
41
 
42
- def image_qa_app(kbvqa):
43
- if 'images_data' not in st.session_state:
44
- st.session_state['images_data'] = {}
45
-
46
  # Display sample images as clickable thumbnails
47
  st.write("Choose from sample images:")
48
  cols = st.columns(len(sample_images))
@@ -51,23 +49,21 @@ def image_qa_app(kbvqa):
51
  image = Image.open(sample_image_path)
52
  st.image(image, use_column_width=True)
53
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
54
- process_new_image(sample_image_path, image, kbvqa)
55
 
56
  # Image uploader
57
  uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
58
  if uploaded_image is not None:
59
- process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
60
 
61
  # Display and interact with each uploaded/selected image
62
- for image_key, image_data in st.session_state['images_data'].items():
63
  st.image(image_data['image'], caption=f'Uploaded Image: {image_key[-11:]}', use_column_width=True)
64
  if not image_data['analysis_done']:
65
  st.text("Cool image, please click 'Analyze Image'..")
66
  if st.button('Analyze Image', key=f'analyze_{image_key}'):
67
- caption, detected_objects_str, image_with_boxes = analyze_image(image_data['image'], kbvqa) # we can use the image_with_boxes later if we want to show it.
68
- image_data['caption'] = caption
69
- image_data['detected_objects_str'] = detected_objects_str
70
- image_data['analysis_done'] = True
71
 
72
  # Initialize qa_history for each image
73
  qa_history = image_data.get('qa_history', [])
@@ -77,16 +73,15 @@ def image_qa_app(kbvqa):
77
  if st.button('Get Answer', key=f'answer_{image_key}'):
78
  if question not in [q for q, _ in qa_history]:
79
  answer = answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
80
- qa_history.append((question, answer))
81
- image_data['qa_history'] = qa_history
82
- else:
83
- st.info("This question has already been asked.")
84
 
85
  # Display Q&A history for each image
86
  for q, a in qa_history:
87
  st.text(f"Q: {q}\nA: {a}\n")
88
 
89
 
 
 
90
  def process_new_image(image_key, image, kbvqa):
91
  """Process a new image and update the session state."""
92
  if image_key not in st.session_state['images_data']:
@@ -100,78 +95,37 @@ def process_new_image(image_key, image, kbvqa):
100
 
101
  def run_inference():
102
  st.title("Run Inference")
103
- st.write("Please note that this is not a general purpose model, it is specifically trained on OK-VQA dataset and is designed to give direct and short answers to the given questions.")
104
-
105
- method = st.selectbox(
106
- "Choose a method:",
107
- ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
108
- index=0
109
- )
110
-
111
- detection_model = st.selectbox(
112
- "Choose a model for objects detection:",
113
- ["yolov5", "detic"],
114
- index=1 # "detic" is selected by default
115
- )
116
 
 
 
117
  default_confidence = 0.2 if detection_model == "yolov5" else 0.4
118
- confidence_level = st.slider(
119
- "Select minimum detection confidence level",
120
- min_value=0.1,
121
- max_value=0.9,
122
- value=default_confidence,
123
- step=0.1
124
- )
125
 
126
- if 'model_settings' not in st.session_state:
127
- st.session_state['model_settings'] = {'detection_model': detection_model, 'confidence_level': confidence_level}
128
-
129
- settings_changed = (st.session_state['model_settings']['detection_model'] != detection_model or
130
- st.session_state['model_settings']['confidence_level'] != confidence_level)
131
-
132
- need_model_reload = settings_changed and 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
133
-
134
- if need_model_reload:
135
- st.text("Model Settings have changed, please reload the model, this will take no time :)")
136
 
 
137
  button_label = "Reload Model" if need_model_reload else "Load Model"
138
 
139
  if method == "Fine-Tuned Model":
140
- if 'kbvqa' not in st.session_state:
141
- st.session_state['kbvqa'] = None
142
-
143
  if st.button(button_label):
144
-
145
- free_gpu_resources()
146
- if st.session_state['kbvqa'] is not None:
147
- if not settings_changed:
148
- st.write("Model already loaded.")
149
- else:
150
- free_gpu_resources()
151
- detection_model = st.session_state['model_settings']['detection_model']
152
- confidence_level = st.session_state['model_settings']['confidence_level']
153
- prepare_kbvqa_model(detection_model, only_reload_detection_model=True) # only reload detection model with new settings
154
- st.session_state['kbvqa'].detection_confidence = confidence_level
155
- free_gpu_resources()
156
  else:
157
- st.text("Loading the model will take no more than a few minutes . .")
158
- st.session_state['kbvqa'] = prepare_kbvqa_model(detection_model)
159
- st.session_state['kbvqa'].detection_confidence = confidence_level
160
- st.session_state['model_settings'] = {'detection_model': detection_model, 'confidence_level': confidence_level}
161
  st.write("Model is ready for inference.")
162
- free_gpu_resources()
163
-
164
-
165
 
166
- if st.session_state['kbvqa']:
167
- display_model_settings()
168
- display_session_state()
169
- image_qa_app(st.session_state['kbvqa'])
170
 
171
  else:
172
  st.write('Model is not ready yet, will be updated later.')
173
 
174
 
 
175
  def display_model_settings():
176
  st.write("### Current Model Settings:")
177
  st.table(pd.DataFrame(st.session_state['model_settings'], index=[0]))
 
11
  from my_model.captioner.image_captioning import get_caption
12
  from my_model.gen_utilities import free_gpu_resources
13
  from my_model.KBVQA import KBVQA, prepare_kbvqa_model
14
+ from my_model.utilities.st_utils import UIManager, StateManager
15
 
16
+ state_manager = StateManager()
17
 
18
  def answer_question(caption, detected_objects_str, question, model):
19
  free_gpu_resources()
 
40
  return caption, detected_objects_str, image_with_boxes
41
 
42
 
43
+ def image_qa_app(state_manager, kbvqa):
 
 
 
44
  # Display sample images as clickable thumbnails
45
  st.write("Choose from sample images:")
46
  cols = st.columns(len(sample_images))
 
49
  image = Image.open(sample_image_path)
50
  st.image(image, use_column_width=True)
51
  if st.button(f'Select Sample Image {idx + 1}', key=f'sample_{idx}'):
52
+ state_manager.process_new_image(sample_image_path, image, kbvqa)
53
 
54
  # Image uploader
55
  uploaded_image = st.file_uploader("Or upload an Image", type=["png", "jpg", "jpeg"])
56
  if uploaded_image is not None:
57
+ state_manager.process_new_image(uploaded_image.name, Image.open(uploaded_image), kbvqa)
58
 
59
  # Display and interact with each uploaded/selected image
60
+ for image_key, image_data in state_manager.get_images_data().items():
61
  st.image(image_data['image'], caption=f'Uploaded Image: {image_key[-11:]}', use_column_width=True)
62
  if not image_data['analysis_done']:
63
  st.text("Cool image, please click 'Analyze Image'..")
64
  if st.button('Analyze Image', key=f'analyze_{image_key}'):
65
+ caption, detected_objects_str, image_with_boxes = state_manager.analyze_image(image_data['image'], kbvqa)
66
+ state_manager.update_image_data(image_key, caption, detected_objects_str, True)
 
 
67
 
68
  # Initialize qa_history for each image
69
  qa_history = image_data.get('qa_history', [])
 
73
  if st.button('Get Answer', key=f'answer_{image_key}'):
74
  if question not in [q for q, _ in qa_history]:
75
  answer = answer_question(image_data['caption'], image_data['detected_objects_str'], question, kbvqa)
76
+ state_manager.add_to_qa_history(image_key, question, answer)
 
 
 
77
 
78
  # Display Q&A history for each image
79
  for q, a in qa_history:
80
  st.text(f"Q: {q}\nA: {a}\n")
81
 
82
 
83
+
84
+
85
  def process_new_image(image_key, image, kbvqa):
86
  """Process a new image and update the session state."""
87
  if image_key not in st.session_state['images_data']:
 
95
 
96
  def run_inference():
97
  st.title("Run Inference")
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ method = st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0)
100
+ detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1)
101
  default_confidence = 0.2 if detection_model == "yolov5" else 0.4
102
+ confidence_level = st.slider("Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1)
 
 
 
 
 
 
103
 
104
+ state_manager.update_model_settings(detection_model, confidence_level, method)
105
+ settings_changed = state_manager.check_settings_changed(method, detection_model, confidence_level)
 
 
 
 
 
 
 
 
106
 
107
+ need_model_reload = settings_changed and state_manager.is_model_loaded()
108
  button_label = "Reload Model" if need_model_reload else "Load Model"
109
 
110
  if method == "Fine-Tuned Model":
 
 
 
111
  if st.button(button_label):
112
+ if state_manager.is_model_loaded() and not settings_changed:
113
+ st.write("Model already loaded.")
 
 
 
 
 
 
 
 
 
 
114
  else:
115
+ st.text("Loading the model, please wait...")
116
+ state_manager.load_model(detection_model, confidence_level)
 
 
117
  st.write("Model is ready for inference.")
 
 
 
118
 
119
+ if state_manager.is_model_loaded():
120
+ state_manager.display_model_settings()
121
+ state_manager.display_session_state()
122
+ image_qa_app(state_manager.get_model())
123
 
124
  else:
125
  st.write('Model is not ready yet, will be updated later.')
126
 
127
 
128
+
129
  def display_model_settings():
130
  st.write("### Current Model Settings:")
131
  st.table(pd.DataFrame(st.session_state['model_settings'], index=[0]))