File size: 20,021 Bytes
7138ab3 8286a8f 7138ab3 18d1852 e46d486 69b926c 0a62769 6740cd3 18d1852 49e9e5b a812c7b 18d1852 11a17ef 7138ab3 8286a8f 6d4d5ac 8286a8f 7138ab3 8286a8f 7138ab3 8286a8f 7138ab3 18d1852 7138ab3 eaa0bad 8286a8f 11a17ef 8286a8f 11a17ef 7138ab3 eaa0bad 8286a8f ccdbda8 18d1852 8286a8f 12f08dc a264403 96dd295 69b926c 1f66159 9becb2c ccdbda8 eaa0bad cf8c147 7138ab3 cf8c147 929a4bd 8286a8f 11a17ef 8286a8f 18d1852 08655fb 8286a8f 11a17ef a5c03d9 3fdd1d7 8286a8f cf8c147 7138ab3 cf8c147 8286a8f bfdde42 8286a8f bfdde42 8286a8f 3fdd1d7 141a983 7138ab3 8286a8f 7138ab3 8286a8f 141a983 de78d1a 7138ab3 eaa0bad 8286a8f 7138ab3 eaa0bad 8286a8f de78d1a eaa0bad f4bcc28 7138ab3 cf8c147 8286a8f cf8c147 8286a8f f4bcc28 eaa0bad 8624b37 7138ab3 cf8c147 eaa0bad cf8c147 8286a8f eaa0bad 18d1852 7138ab3 cf8c147 eaa0bad 7138ab3 cf8c147 8286a8f 2250430 eaa0bad cf8c147 8286a8f cf8c147 7138ab3 cf8c147 8286a8f 753c201 cc825df d9364fd f72214b 2250430 a264403 12f08dc 753c201 e2de402 8286a8f 753c201 f72214b eaa0bad 8286a8f eaa0bad 7138ab3 eaa0bad 8286a8f ffdb10e 96dd295 ffdb10e a812c7b ffdb10e 2250430 8286a8f ffdb10e a812c7b ffdb10e 96dd295 a812c7b 7138ab3 eaa0bad 96dd295 eaa0bad 7138ab3 96dd295 8286a8f 96dd295 8286a8f 9becb2c da52f83 e2de402 da52f83 e2de402 da52f83 8286a8f eaa0bad cf8c147 8286a8f cf8c147 f72214b 2250430 7e798e5 f72214b 8286a8f 753c201 453b185 cf8c147 7138ab3 8286a8f 7138ab3 cf8c147 8286a8f 18d1852 9becb2c eaa0bad cf8c147 8286a8f cf8c147 8286a8f 0675d16 8286a8f 18d1852 eaa0bad cf8c147 8286a8f cf8c147 7138ab3 cf8c147 8286a8f 18d1852 9becb2c 1fc0405 9a8f19a 11a17ef 2250430 9becb2c 0ed508e 18d1852 eaa0bad cf8c147 8286a8f cf8c147 8286a8f cf8c147 7138ab3 cf8c147 8286a8f 18d1852 eaa0bad cf8c147 8286a8f cf8c147 1a4044e cf8c147 e2de402 18d1852 1f66159 c6c0b0e 87db14a 18d1852 eaa0bad cf8c147 7138ab3 8286a8f cf8c147 7138ab3 cf8c147 8286a8f 18d1852 6a15fc4 18d1852 7138ab3 cf8c147 8286a8f cf8c147 8286a8f 18d1852 8286a8f 7138ab3 9becb2c 8286a8f 9becb2c 7138ab3 9becb2c 726be01 9becb2c 18d1852 7138ab3 
f993077 7138ab3 f993077 8286a8f f993077 7138ab3 8286a8f f993077 7138ab3 f993077 8286a8f f993077 2cd3e03 8286a8f f993077 8286a8f f993077 ba99375 7138ab3 8286a8f 7138ab3 8286a8f 9becb2c 8286a8f 9becb2c 8286a8f 8d3df8c 7138ab3 eaa0bad 7138ab3 eaa0bad 8286a8f eaa0bad 7138ab3 eaa0bad 7138ab3 8286a8f 7138ab3 eaa0bad 8286a8f 26f093e eaa0bad 26f093e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 |
# This module contains the StateManager class.
# The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run,
# and test the models.
import pandas as pd
import copy
import time
from PIL import Image
from typing import Tuple, Dict, Optional
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model
class StateManager:
    """
    Manages the user interface and session state for the Streamlit-based Knowledge-Based Visual Question Answering
    (KBVQA) application.

    This class includes methods to initialize the session state, set up various UI widgets for model selection and
    settings, manage the loading and reloading of the KBVQA model, and handle the processing and analysis of images.
    It tracks changes to the application's state to ensure the correct configuration is maintained.
    Additionally, it provides methods to display the current model settings and the complete application state within
    the Streamlit interface.

    The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run,
    and test the models.

    Attributes:
        col1 (streamlit.columns): The first column in the Streamlit layout.
        col2 (streamlit.columns): The second column in the Streamlit layout.
        col3 (streamlit.columns): The third column in the Streamlit layout.
    """

    # Session-state keys shown in the "Current Model Settings" table.
    _SETTINGS_KEYS = ("confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed',
                      'loading_in_progress', 'model_loaded', 'time_taken_to_load_model', 'images_data')

    def __init__(self) -> None:
        """
        Initializes the StateManager instance, setting up the Streamlit columns for the user interface.
        """
        # Create three columns with different widths
        self.col1, self.col2, self.col3 = st.columns([0.2, 0.6, 0.2])

    def initialize_state(self) -> None:
        """
        Initializes the Streamlit session state with default values for any keys that are not already present.

        Existing values are never overwritten, so this is safe to call on every Streamlit rerun.
        """
        # Built fresh on each call so the mutable defaults (dicts) are never shared between sessions.
        defaults = {
            'previous_state': {'method': None, 'detection_model': None, 'confidence_level': None},
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
        }
        for key, default in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default

        # These two are derived from the keys seeded above, so they are computed only after those exist.
        if "settings_changed" not in st.session_state:
            st.session_state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in st.session_state:
            st.session_state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.

        Returns:
            None
        """
        self.col1.selectbox("Choose a model:",
                            ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"],
                            index=1, key='method', disabled=self.is_widget_disabled)
        detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1,
                                              key='detection_model', disabled=self.is_widget_disabled)
        # Each detector gets its own default minimum confidence.
        default_confidence = 0.2 if detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)

        # Conditional display of model settings
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            self._render_model_settings()

    def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float,
                         slider_key_name: str, col=None) -> float:
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None (displayed in
                main area).

        Returns:
            float: The current value of the slider widget.
        """
        container = st if col is None else col
        return container.slider(text, min_value, max_value, value, step, key=slider_key_name,
                                disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self) -> bool:
        """
        Checks if widgets should be disabled based on the 'loading_in_progress' state.

        Returns:
            bool: True if widgets should be disabled, False otherwise.
        """
        return st.session_state['loading_in_progress']

    def disable_widgets(self) -> None:
        """
        Disables widgets by setting the 'loading_in_progress' state to True.

        Returns:
            None
        """
        st.session_state['loading_in_progress'] = True

    @property
    def settings_changed(self) -> bool:
        """
        Checks if any model settings (other than the confidence level) have changed compared to the previous state.

        Returns:
            bool: True if any setting has changed, False otherwise.
        """
        return self.has_state_changed()

    @property
    def confidence_change(self) -> bool:
        """
        Checks if the confidence level setting has changed compared to the previous state.

        Returns:
            bool: True if the confidence level has changed, False otherwise.
        """
        return st.session_state["confidence_level"] != st.session_state["previous_state"]["confidence_level"]

    @property
    def confidance_change(self) -> bool:
        """
        Backward-compatible alias of `confidence_change` (original, misspelled name).

        Returns:
            bool: True if the confidence level has changed, False otherwise.
        """
        return self.confidence_change

    def update_prev_state(self) -> None:
        """
        Updates the 'previous_state' in the session state with the current state values.

        Returns:
            None
        """
        for key in st.session_state['previous_state']:
            st.session_state['previous_state'][key] = st.session_state[key]

    def load_model(self) -> None:
        """
        Loads the KBVQA model based on the chosen method and settings.

        - Frees GPU resources before loading.
        - Calls `prepare_kbvqa_model` to create the model.
        - Sets the detection confidence level on the model object.
        - Updates previous state with current settings for change detection.
        - Updates the button label to "Reload Model".

        Errors are reported in the UI with `st.error` rather than raised.

        Returns:
            None
        """
        try:
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Update the previous state with current session state values
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def force_reload_model(self) -> None:
        """
        Forces a reload of all models, freeing up GPU resources.

        - Deletes the current KBVQA model from the session state (via `delete_model`).
        - Calls `prepare_kbvqa_model` with `force_reload=True` to reload the model.
        - Updates the detection confidence level on the model object.
        - Updates previous state with current settings and marks the model as loaded.

        Errors are reported in the UI with `st.error` rather than raised.

        Returns:
            None
        """
        try:
            self.delete_model()
            free_gpu_resources()
            st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Update the previous state with current session state values
            self.update_prev_state()
            st.session_state['model_loaded'] = True
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading model: {e}")
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Deletes the current KBVQA model from the session state (if loaded) and calls `free_gpu_resources`.

        Returns:
            None
        """
        free_gpu_resources()
        if self.is_model_loaded:
            try:
                del st.session_state['kbvqa']
            except KeyError:
                # The model entry is already gone; nothing left to delete.
                pass
            finally:
                free_gpu_resources()
                free_gpu_resources()

    def has_state_changed(self) -> bool:
        """
        Compares current session state with the previous state to identify changes.

        The confidence level is deliberately excluded; it is tracked separately by `confidence_change`.

        Returns:
            bool: True if any tracked setting differs from its previous value, False otherwise.
        """
        for key, previous_value in st.session_state['previous_state'].items():
            if key == 'confidence_level':
                continue  # confidence_level tracker is separate
            if key in st.session_state and st.session_state[key] != previous_value:
                return True  # Found a change
        # Every tracked key was checked and none differ.
        return False

    def get_model(self) -> Optional[KBVQA]:
        """
        Retrieves the KBVQA model from the session state.

        Returns:
            KBVQA: The loaded KBVQA model, or None if not loaded.
        """
        return st.session_state.get('kbvqa', None)

    @property
    def is_model_loaded(self) -> bool:
        """
        Checks if the KBVQA model is loaded in the session state.

        The model counts as loaded only if it exists, reports `all_models_loaded`, and was loaded for the method
        currently selected in the UI.

        Returns:
            bool: True if the model is loaded, False otherwise.
        """
        model = st.session_state.get('kbvqa')
        if model is None or not model.all_models_loaded:
            return False
        previous_method = st.session_state['previous_state']['method']
        return previous_method is not None and st.session_state['method'] == previous_method

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Frees GPU resources before reloading.
        - Checks if the model is already loaded.
        - Calls `prepare_kbvqa_model` with `only_reload_detection_model=True`.
        - Updates detection confidence level on the model object.
        - Displays a success message if model is reloaded successfully.
        - Records the applied settings as the new previous state.

        Errors are reported in the UI with `st.error` rather than raised.

        Returns:
            None
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded:
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
                # Record the settings just applied so change detection no longer reports them as pending.
                self.update_prev_state()
                st.session_state['button_label'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key: str, image) -> None:
        """
        Processes a new uploaded image by creating an entry in the `images_data` dictionary in the application session
        state. An image whose key is already present is left untouched.

        This dictionary stores information about each processed image, including:
            - `image`: The original image data.
            - `caption`: Generated caption for the image.
            - `detected_objects_str`: String representation of detected objects.
            - `qa_history`: List of questions and answers related to the image.
            - `analysis_done`: Flag indicating if analysis is complete.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.

        Returns:
            None
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image) -> Tuple[str, str, object]:
        """
        Analyzes the image using the KBVQA model.

        - Creates a deep copy of the image so the caller's object is not modified.
        - Calls KBVQA methods to generate a caption and detect objects.
        - Returns the generated caption, detected objects string, and image with bounding boxes.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: A tuple containing the generated caption, detected objects string, and image with bounding boxes.
        """
        free_gpu_resources()
        free_gpu_resources()
        img = copy.deepcopy(image)
        model = st.session_state['kbvqa']
        caption = model.get_caption(img)
        image_with_boxes, detected_objects_str = model.detect_objects(img)
        free_gpu_resources()
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
        """
        Adds a question-answer pair to the QA history of a specific image, to be used as a history tracker.
        Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length (int): The length of the prompt used for generating the answer.

        Returns:
            None
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer, prompt_length))

    def get_images_data(self) -> Dict:
        """
        Returns the dictionary containing processed image data from the session state.

        Returns:
            dict: The dictionary storing information about processed images.
        """
        return st.session_state['images_data']

    def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
        """
        Updates the information stored for a specific image in the `images_data` dictionary in the application session
        state. Unknown image keys are ignored.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.

        Returns:
            None
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })

    @staticmethod
    def _target_size(original_size, new_width: Optional[int], new_height: Optional[int]) -> Tuple[int, int]:
        """Fill in whichever of width/height is missing so the aspect ratio is preserved (never below 1 px)."""
        original_width, original_height = original_size
        if new_width is not None and new_height is None:
            ratio = new_width / original_width
            new_height = max(1, int(original_height * ratio))
        elif new_width is None and new_height is not None:
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))
        return new_width, new_height

    def resize_image(self, image_input, new_width: Optional[int] = None,
                     new_height: Optional[int] = None) -> Image.Image:
        """
        Resizes an image. If only new_width is provided, the height is adjusted to maintain aspect ratio (and vice
        versa). If both new_width and new_height are provided, the image is resized to those dimensions.

        Args:
            image_input (str | PIL.Image.Image): The image to resize, or a path to an image file.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image (a new object; the input is not modified).

        Raises:
            ValueError: If image_input is neither a path nor a PIL Image, or if neither dimension is given.
        """
        if not isinstance(image_input, (str, Image.Image)):
            raise ValueError("image_input must be a file path or a PIL Image object")
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        if isinstance(image_input, str):
            # Open the file only while producing the resized copy so the file handle is released.
            with Image.open(image_input) as opened:
                return opened.resize(self._target_size(opened.size, new_width, new_height))

        # Image.resize returns a new image, so the caller's object is left unmodified without an explicit copy.
        return image_input.resize(self._target_size(image_input.size, new_width, new_height))

    def display_message(self, message: str, message_type: str) -> None:
        """
        Displays a message in the Streamlit interface based on the specified message type.

        Args:
            message (str): The message to display.
            message_type (str): The type of message ('warning', 'text', 'success', 'write', or 'error').
                Any other value shows a "Message type unknown" error instead.

        Returns:
            None
        """
        renderers = {
            "warning": st.warning,
            "text": st.text,
            "success": st.success,
            "write": st.write,
            "error": st.error,
        }
        renderer = renderers.get(message_type)
        if renderer is None:
            st.error("Message type unknown")
        else:
            renderer(message)

    def _render_model_settings(self):
        """Write the table of current model settings into the third column and return the `write` result."""
        self.col3.write("##### Current Model Settings:")
        data = [{'Setting': key, 'Value': str(value)} for key, value in st.session_state.items()
                if key in self._SETTINGS_KEYS]
        df = pd.DataFrame(data).reset_index(drop=True)
        return self.col3.write(df)

    @property
    def display_model_settings(self) -> None:
        """
        Displays a table of current model settings in the third column.

        Kept as a property for backward compatibility: merely accessing it renders the table.

        Returns:
            None
        """
        return self._render_model_settings()

    def display_session_state(self, col) -> None:
        """
        Displays a table of the complete application state in the specified column.

        Args:
            col (streamlit.columns.Column): The Streamlit column to display the session state.

        Returns:
            None
        """
        col.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data).reset_index(drop=True)
        col.write(df)
|