File size: 17,085 Bytes
18d1852 e46d486 69b926c 0a62769 18d1852 49e9e5b a812c7b 18d1852 11a17ef 6d4d5ac 18d1852 11a17ef 18d1852 12f08dc 4170a5f 74d450c a264403 96dd295 69b926c 1f66159 155a844 9becb2c a264403 9becb2c 2152f1f 2957e90 1b71503 cf8c147 929a4bd b40d522 1d68663 11a17ef b42fa65 18d1852 08655fb 11a17ef fcf48a1 11a17ef a5c03d9 3fdd1d7 d80fd56 bfdde42 0ed508e cf8c147 bfdde42 4e59653 bfdde42 4e59653 3fdd1d7 9becb2c 141a983 de78d1a d80fd56 f4bcc28 cf8c147 f4bcc28 18d1852 a5c03d9 18d1852 cf8c147 11a17ef 69b926c a5c03d9 18d1852 3fdd1d7 18d1852 cf8c147 3689b26 18d1852 e1f7f87 11c350b 3fdd1d7 18d1852 d0e9fe6 cf8c147 753c201 cc825df d9364fd f72214b a264403 12f08dc 753c201 2b3b1de 753c201 f72214b a812c7b ffdb10e 96dd295 ffdb10e a812c7b ffdb10e a812c7b ffdb10e 96dd295 a812c7b ffdb10e 96dd295 9becb2c da52f83 ffdb10e f72214b cf8c147 f72214b 87da9a2 753c201 3fdd1d7 18d1852 cf8c147 18d1852 9becb2c 18d1852 cf8c147 c70ac46 18d1852 3fdd1d7 cd3678b cf8c147 18d1852 9becb2c 6d688ae b4c744e 503215f b4c744e 6d688ae 503215f 6d688ae 155a844 503215f 1fc0405 9a8f19a 11a17ef 9becb2c 0ed508e 18d1852 d6f8382 09f5cd2 cf8c147 1a4044e cf8c147 18d1852 3fdd1d7 09f5cd2 cf8c147 1a4044e cf8c147 18d1852 d182243 1f66159 6227ce3 c6c0b0e 5c3c21e 18d1852 3fdd1d7 6a15fc4 cf8c147 18d1852 6a15fc4 18d1852 3fdd1d7 18d1852 cf8c147 18d1852 9becb2c 155a844 9becb2c 18d1852 afa0e9e f993077 2cd3e03 f993077 ba99375 436f3b4 afa0e9e 9becb2c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 |
import pandas as pd
import copy
import time
from PIL import Image
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """
    Manages the Streamlit session state and the shared widgets of the KBVQA app.

    The class wraps `st.session_state` access (model object, per-image data,
    UI flags) and offers helpers to load, reload and delete the model returned
    by `prepare_kbvqa_model`.
    """

    # Settings snapshotted into `previous_state` after every (re)load so that
    # later widget changes can be detected by `has_state_changed`.
    _TRACKED_SETTINGS = ('method', 'detection_model', 'confidence_level')

    # Session-state keys shown in the "Current Model Settings" table.
    _DISPLAYED_SETTINGS = (
        'confidence_level', 'detection_model', 'method', 'kbvqa',
        'previous_state', 'settings_changed', 'loading_in_progress',
        'model_loaded', 'time_taken_to_load_model',
    )

    def __init__(self):
        """Creates the three layout columns used by the widgets."""
        # Create three columns with different widths
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self) -> None:
        """
        Populates the session state with default values for every key that is not set yet.

        Existing values are never overwritten, so the method is safe to call on every rerun.
        """
        # Built on every call so the mutable defaults are fresh objects.
        static_defaults = {
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'previous_state': {},
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
            'confidence_level_changed': False,
        }
        for key, default in static_defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default

        # These two defaults are computed from the state initialised above,
        # so they must be evaluated after the loop.
        if 'settings_changed' not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
        """
        self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"],
                            index=0, key='method', disabled=self.is_widget_disabled)
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"],
                            index=0, key='detection_model', disabled=self.is_widget_disabled)

        # The initial slider value depends on the selected detector.
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.05, slider_key_name='confidence_level',
                              col=self.col1)

        # Conditional display of model settings
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            # `display_model_settings` is a property; reading it renders the table.
            self.display_model_settings

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in main area).

        Returns:
            The value returned by the underlying `slider` call.
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step,
                                key=slider_key_name, disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self):
        """Value of the `loading_in_progress` flag (False when the flag has not been initialised)."""
        return st.session_state.get('loading_in_progress', False)

    def disable_widgets(self) -> None:
        """Sets the `loading_in_progress` flag, which the widgets read through `is_widget_disabled`."""
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self) -> bool:
        """
        Checks if any model settings have changed compared to the previous state.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    @property
    def display_model_settings(self):
        """
        Displays a table of current model settings in the third column.

        Returns:
            The value returned by the column's `write` call.
        """
        self.col3.write("##### Current Model Settings:")
        rows = [{'Setting': key, 'Value': str(value)}
                for key, value in st.session_state.items()
                if key in self._DISPLAYED_SETTINGS]
        table = pd.DataFrame(rows).reset_index(drop=True)
        return self.col3.write(table)

    def display_session_state(self) -> None:
        """
        Displays a table of the complete application state.
        """
        st.write("Current Model:")
        rows = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        st.write(pd.DataFrame(rows).reset_index(drop=True))

    def _snapshot_settings(self) -> None:
        """Stores the current values of the tracked settings as `previous_state`."""
        st.session_state['previous_state'] = {key: st.session_state[key] for key in self._TRACKED_SETTINGS}

    def _install_model(self, model) -> None:
        """Stores a freshly prepared model, applies the confidence level and records the settings used."""
        st.session_state['kbvqa'] = model
        st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
        self._snapshot_settings()
        st.session_state['model_loaded'] = True

    def load_model(self) -> None:
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Calls `free_gpu_resources` before loading and again when finished.
        - Calls `prepare_kbvqa_model` to create the model.
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Updates the button label to "Reload Model".

        Any exception is reported through `st.error` instead of being raised.
        """
        try:
            free_gpu_resources()
            self._install_model(prepare_kbvqa_model())
            st.session_state['button_label'] = "Reload Model"
        except Exception as e:
            st.error(f"Error loading model: {e}")
        finally:
            free_gpu_resources()

    def force_reload_model(self) -> None:
        """
        Deletes the current model and prepares a new one with `force_reload=True`.

        Any exception is reported through `st.error` instead of being raised;
        `free_gpu_resources` is called afterwards in every case.
        """
        try:
            self.delete_model()
            free_gpu_resources()
            self._install_model(prepare_kbvqa_model(force_reload=True))
        except Exception as e:
            st.error(f"Error reloading model: {e}")
        finally:
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Removes the current model from the session state and calls `free_gpu_resources`.

        The deletion is best effort: a failure to remove the entry is ignored.
        """
        free_gpu_resources()
        # NOTE(review): only a fully loaded model is deleted; a partially
        # loaded one stays in the session state - confirm this is intended.
        if self.is_model_loaded:
            try:
                del st.session_state['kbvqa']
            except Exception:
                # Deliberately ignored: deletion is best effort.
                pass
            finally:
                free_gpu_resources()

    # Function to check if any session state values have changed
    def has_state_changed(self) -> bool:
        """
        Compares current session state with the previous state to identify changes.

        A tracked key that is missing from the session state counts as changed.

        Returns:
            bool: True if any change is found, False otherwise.
        """
        previous = st.session_state.get('previous_state') or {}
        missing = object()  # sentinel that never equals a stored value
        return any(st.session_state.get(key, missing) != old_value
                   for key, old_value in previous.items())

    def get_model(self):
        """
        Retrieve the KBVQA model from the session state.

        Returns:
            KBVQA object: The loaded KBVQA model, or None if not loaded.
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self) -> bool:
        """
        Checks if the KBVQA model is loaded in the session state.

        Returns:
            bool: True if a model object is stored and reports `all_models_loaded`, False otherwise.
        """
        model = st.session_state.get('kbvqa', None)
        return model is not None and bool(model.all_models_loaded)

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Calls `free_gpu_resources` before reloading and again when finished.
        - Does nothing else when no model is loaded.
        - When only the confidence level changed, updates it on the model
          object, sets `confidence_level_changed` and returns without reloading.
        - Otherwise calls `prepare_kbvqa_model` with `only_reload_detection_model=True`,
          updates the confidence level, displays a success message, records
          the settings used and sets the button label.

        Any exception is reported through `st.error` instead of being raised.
        """
        try:
            free_gpu_resources()
            if not self.is_model_loaded:
                return

            previous = st.session_state['previous_state']
            same_detector = st.session_state['detection_model'] == previous['detection_model']
            same_method = st.session_state['method'] == previous['method']
            confidence_changed = st.session_state['confidence_level'] != previous['confidence_level']

            # check if only confidence is changed
            if same_detector and same_method and confidence_changed:
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                st.session_state['confidence_level_changed'] = True
                # NOTE(review): `previous_state` is not refreshed on this path,
                # so `has_state_changed` keeps reporting a change - confirm
                # that callers rely on this.
                return  # only update the confidence level

            prepare_kbvqa_model(only_reload_detection_model=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            self.col1.success("Model reloaded with updated settings and ready for inference.")
            self._snapshot_settings()
            st.session_state['button_label'] = "Reload Model"
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")
        finally:
            free_gpu_resources()

    def process_new_image(self, image_key, image) -> None:
        """
        Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session state.

        An existing entry for `image_key` is left untouched. The entry stores:
        - `image`: The original image data.
        - `caption`: Generated caption for the image.
        - `detected_objects_str`: String representation of detected objects.
        - `qa_history`: List of questions and answers related to the image.
        - `analysis_done`: Flag indicating if analysis is complete.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False,
            }

    def analyze_image(self, image):
        """
        Analyzes the image using the KBVQA model.

        - Creates a copy of the image so the model calls never receive the original object.
        - Displays a "Analyzing the image .." message.
        - Calls the model's `get_caption` and `detect_objects` methods.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
        """
        working_copy = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        model = st.session_state['kbvqa']
        caption = model.get_caption(working_copy)
        image_with_boxes, detected_objects_str = model.detect_objects(working_copy)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer, prompt_length) -> None:
        """
        Adds a question-answer pair to the QA history of a specific image, to be used as history tracker.

        Nothing is recorded when `image_key` has no entry in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length: Value stored alongside the pair as the third tuple element.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self):
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done) -> None:
        """
        Updates the information stored for a specific image in the `images_data` dictionary in the application session state.

        Nothing happens when `image_key` has no entry: there is nothing to update.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.
        """
        entry = st.session_state['images_data'].get(image_key)
        if entry is None:
            return
        entry.update({
            'caption': caption,
            'detected_objects_str': detected_objects_str,
            'analysis_done': analysis_done,
        })

    @staticmethod
    def _fit_dimensions(original_size, new_width=None, new_height=None):
        """
        Computes the target (width, height), deriving a missing side from the aspect ratio.

        Args:
            original_size (tuple): The (width, height) of the source image.
            new_width (int, optional): The target width.
            new_height (int, optional): The target height.

        Returns:
            tuple: The (width, height) to resize to; a derived side is at least 1.

        Raises:
            ValueError: If neither `new_width` nor `new_height` is provided.
        """
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        original_width, original_height = original_size
        if new_height is None:
            # Calculate new height to maintain aspect ratio
            ratio = new_width / original_width
            new_height = max(1, int(original_height * ratio))
        elif new_width is None:
            # Calculate new width to maintain aspect ratio
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))
        return new_width, new_height

    def resize_image(self, image_input, new_width=None, new_height=None):
        """
        Resize an image. If only one of new_width / new_height is provided, the other is adjusted to maintain aspect ratio.
        If both new_width and new_height are provided, the image is resized to those dimensions.

        Args:
            image_input (str or PIL.Image.Image): File path of the image, or the image itself.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image.

        Raises:
            ValueError: If neither dimension is provided, or if `image_input` is neither a path nor a PIL image.
        """
        # Validate the requested size first so no file is opened for an invalid call.
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        if isinstance(image_input, str):
            # Open the image from a file path; the file is closed once the resized copy exists.
            with Image.open(image_input) as opened:
                return opened.resize(self._fit_dimensions(opened.size, new_width, new_height))

        if isinstance(image_input, Image.Image):
            # `resize` returns a new image, so the caller's object is not modified.
            return image_input.resize(self._fit_dimensions(image_input.size, new_width, new_height))

        raise ValueError("image_input must be a file path or a PIL Image object")

    def display_message(self, message, message_type) -> None:
        """
        Displays a message with the Streamlit element matching `message_type`.

        Args:
            message (str): The message to display.
            message_type (str): One of "warning", "text", "success" or "write";
                any other value shows a "Message type unknown" error instead.
        """
        # Built at call time because the renderers are attributes of `st`.
        renderers = {
            "warning": st.warning,
            "text": st.text,
            "success": st.success,
            "write": st.write,
        }
        renderer = renderers.get(message_type)
        if renderer is None:
            st.error("Message type unknown")
        else:
            renderer(message)
|