File size: 20,021 Bytes
7138ab3
8286a8f
 
7138ab3
 
18d1852
e46d486
69b926c
0a62769
6740cd3
18d1852
49e9e5b
a812c7b
18d1852
11a17ef
7138ab3
 
8286a8f
 
6d4d5ac
8286a8f
 
7138ab3
 
8286a8f
 
7138ab3
8286a8f
 
7138ab3
 
 
 
 
 
18d1852
7138ab3
eaa0bad
 
 
8286a8f
11a17ef
8286a8f
11a17ef
7138ab3
eaa0bad
 
 
8286a8f
ccdbda8
 
18d1852
 
 
 
8286a8f
12f08dc
a264403
 
96dd295
 
 
 
69b926c
1f66159
9becb2c
 
 
 
ccdbda8
eaa0bad
cf8c147
 
7138ab3
 
 
cf8c147
929a4bd
8286a8f
 
 
 
 
11a17ef
8286a8f
 
18d1852
08655fb
8286a8f
11a17ef
a5c03d9
3fdd1d7
8286a8f
 
cf8c147
 
 
 
 
 
 
 
 
 
 
7138ab3
 
 
cf8c147
8286a8f
bfdde42
8286a8f
 
bfdde42
8286a8f
 
3fdd1d7
141a983
7138ab3
 
 
8286a8f
7138ab3
 
 
8286a8f
141a983
de78d1a
7138ab3
eaa0bad
 
8286a8f
7138ab3
 
eaa0bad
8286a8f
de78d1a
eaa0bad
f4bcc28
7138ab3
cf8c147
 
8286a8f
cf8c147
 
 
8286a8f
f4bcc28
eaa0bad
8624b37
7138ab3
cf8c147
eaa0bad
 
 
 
cf8c147
8286a8f
eaa0bad
18d1852
7138ab3
cf8c147
eaa0bad
7138ab3
 
 
cf8c147
8286a8f
2250430
 
eaa0bad
 
cf8c147
 
8286a8f
cf8c147
 
 
 
 
7138ab3
 
 
cf8c147
8286a8f
753c201
 
cc825df
d9364fd
f72214b
2250430
a264403
12f08dc
753c201
e2de402
8286a8f
753c201
f72214b
 
eaa0bad
 
8286a8f
 
eaa0bad
 
 
 
 
7138ab3
 
 
eaa0bad
8286a8f
ffdb10e
96dd295
ffdb10e
a812c7b
ffdb10e
 
2250430
8286a8f
ffdb10e
a812c7b
ffdb10e
96dd295
a812c7b
7138ab3
eaa0bad
96dd295
eaa0bad
7138ab3
 
 
96dd295
8286a8f
96dd295
8286a8f
9becb2c
da52f83
 
 
e2de402
da52f83
e2de402
da52f83
 
8286a8f
eaa0bad
cf8c147
 
8286a8f
cf8c147
 
 
f72214b
2250430
 
7e798e5
f72214b
8286a8f
 
753c201
453b185
cf8c147
7138ab3
8286a8f
7138ab3
 
cf8c147
8286a8f
18d1852
 
9becb2c
eaa0bad
cf8c147
 
8286a8f
cf8c147
 
 
8286a8f
0675d16
8286a8f
 
 
18d1852
eaa0bad
cf8c147
 
8286a8f
cf8c147
 
 
 
 
7138ab3
 
 
cf8c147
8286a8f
18d1852
 
9becb2c
1fc0405
9a8f19a
11a17ef
2250430
9becb2c
0ed508e
18d1852
 
 
 
eaa0bad
cf8c147
8286a8f
 
 
cf8c147
 
 
 
 
 
8286a8f
cf8c147
 
 
7138ab3
 
 
cf8c147
8286a8f
18d1852
 
 
 
 
 
 
 
 
eaa0bad
cf8c147
 
 
 
 
 
 
8286a8f
cf8c147
 
1a4044e
cf8c147
 
 
e2de402
 
 
18d1852
1f66159
c6c0b0e
87db14a
18d1852
 
eaa0bad
cf8c147
7138ab3
8286a8f
cf8c147
 
 
 
7138ab3
 
 
 
cf8c147
8286a8f
18d1852
6a15fc4
18d1852
7138ab3
cf8c147
 
8286a8f
cf8c147
 
 
8286a8f
18d1852
8286a8f
7138ab3
9becb2c
8286a8f
 
 
9becb2c
 
 
 
 
7138ab3
 
 
9becb2c
726be01
9becb2c
 
 
 
 
18d1852
7138ab3
f993077
7138ab3
f993077
8286a8f
f993077
7138ab3
 
 
8286a8f
f993077
7138ab3
f993077
8286a8f
f993077
2cd3e03
 
 
 
 
 
 
 
8286a8f
f993077
 
 
 
 
 
 
 
 
 
 
 
8286a8f
f993077
 
 
ba99375
7138ab3
 
 
8286a8f
7138ab3
 
 
 
 
 
 
8286a8f
9becb2c
 
 
 
 
8286a8f
9becb2c
 
8286a8f
 
 
8d3df8c
7138ab3
eaa0bad
 
7138ab3
 
 
eaa0bad
 
8286a8f
 
 
eaa0bad
 
 
7138ab3
eaa0bad
7138ab3
8286a8f
7138ab3
 
 
 
 
eaa0bad
8286a8f
26f093e
eaa0bad
 
26f093e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
# This module contains the StateManager class.
# The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run, 
# and test the models.


import pandas as pd
import copy
import time
from PIL import Image
from typing import Tuple, Dict, Optional
import streamlit as st
from my_model.utilities.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model


class StateManager:
    """
    Manages the user interface and session state for the Streamlit-based Knowledge-Based Visual Question Answering 
    (KBVQA) application.

    This class includes methods to initialize the session state, set up various UI widgets for model selection and 
    settings, 
    manage the loading and reloading of the KBVQA model, and handle the processing and analysis of images. 
    It tracks changes to the application's state to ensure the correct configuration is maintained. 
    Additionally, it provides methods to display the current model settings and the complete application state within 
    the Streamlit interface.

    The StateManager class is primarily designed to facilitate the Run Inference tool that allows users to load, run, 
    and test the models.

    Attributes:
        col1 (streamlit.columns): The first column in the Streamlit layout.
        col2 (streamlit.columns): The second column in the Streamlit layout.
        col3 (streamlit.columns): The third column in the Streamlit layout.
    """

    def __init__(self) -> None:
        """
        Builds the three-column page layout and keeps a handle to each column.

        The columns are created with the relative widths 0.2 / 0.6 / 0.2 and stored on the instance as
        `col1`, `col2` and `col3` (left to right).
        """

        # Unpack first so the instance is only touched once all three columns exist.
        left_column, centre_column, right_column = st.columns([0.2, 0.6, 0.2])
        self.col1 = left_column
        self.col2 = centre_column
        self.col3 = right_column

    def initialize_state(self) -> None:
        """
        Seeds the Streamlit session state with a default for every key this class relies on.

        A key that is already present is left untouched. The last two keys ('settings_changed' and
        'model_loaded') are not constants: they store the value of the corresponding property at the moment
        they are first seeded.

        Returns:
            None
        """

        state = st.session_state

        # Plain defaults, listed in the order they are written to the session state.
        constant_defaults = {
            'previous_state': {'method': None, 'detection_model': None, 'confidence_level': None},
            'images_data': {},
            'kbvqa': None,
            'button_label': "Load Model",
            'loading_in_progress': False,
            'load_button_clicked': False,
            'force_reload_button_clicked': False,
            'time_taken_to_load_model': None,
        }
        for name, default in constant_defaults.items():
            if name not in state:
                state[name] = default

        # Derived defaults: the properties read 'previous_state' and 'kbvqa', so they must come after the
        # loop above, and they are evaluated only when the key is actually missing.
        if 'settings_changed' not in state:
            state['settings_changed'] = self.settings_changed
        if 'model_loaded' not in state:
            state['model_loaded'] = self.is_model_loaded

    def set_up_widgets(self) -> None:
        """
        Renders the model-selection widgets and, optionally, the table of current model settings.

        First column:
            - a select box for the model/method (session key 'method'),
            - a select box for the object detection model (session key 'detection_model'),
            - a slider for the minimum detection confidence (session key 'confidence_level'), whose initial
              value is 0.2 when 'yolov5' is selected and 0.4 otherwise.
        Third column:
            - a "Show Model Settings" checkbox (ticked by default); when ticked, the settings table is rendered.

        Every widget is disabled while `is_widget_disabled` is True.

        Returns:
            None
        """

        self.col1.selectbox("Choose a model:",
                            ["13b-Fine-Tuned Model", "7b-Fine-Tuned Model", "Vision-Language Embeddings Alignment"],
                            index=1, key='method', disabled=self.is_widget_disabled)
        # The chosen value is read back through its session key below, so the return value is not kept.
        self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1,
                            key='detection_model', disabled=self.is_widget_disabled)
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.05, slider_key_name='confidence_level', col=self.col1)

        # Conditional display of model settings
        show_model_settings = self.col3.checkbox("Show Model Settings", True, disabled=self.is_widget_disabled)
        if show_model_settings:
            # `display_model_settings` is a property: accessing it is what renders the table.
            self.display_model_settings

    def set_slider_value(self, text: str, min_value: float, max_value: float, value: float, step: float,
                         slider_key_name: str, col=None) -> float:
        """
        Creates a slider widget with the specified parameters, optionally placing it in a specific column.

        The slider is disabled while `is_widget_disabled` is True.

        Args:
            text (str): Text to display next to the slider.
            min_value (float): Minimum value for the slider.
            max_value (float): Maximum value for the slider.
            value (float): Initial value for the slider.
            step (float): Step size for the slider.
            slider_key_name (str): Unique session-state key for the slider.
            col (streamlit.columns.Column, optional): Column to place the slider in. Defaults to None
                (displayed in the main area).

        Returns:
            float: The value returned by the underlying `slider` call.
        """

        # Both `st` and a column expose the same `slider` call, so pick the target once. This also removes
        # the misspelled `is_widget_disabledd` attribute that made the main-area branch raise AttributeError.
        target = st if col is None else col
        return target.slider(text, min_value, max_value, value, step, key=slider_key_name,
                             disabled=self.is_widget_disabled)

    @property
    def is_widget_disabled(self) -> bool:
        """
        Tells whether the UI widgets should currently be rendered as disabled.

        Returns:
            bool: The value stored under the 'loading_in_progress' session key.
        """

        loading_flag = st.session_state['loading_in_progress']
        return loading_flag

    def disable_widgets(self) -> None:
        """
        Marks loading as in progress, which is the flag `is_widget_disabled` reports.

        Returns:
            None
        """

        session = st.session_state
        session['loading_in_progress'] = True

    @property
    def settings_changed(self) -> bool:
        """
        Property view of `has_state_changed()`.

        Returns:
            bool: True if a tracked setting differs from the recorded previous state, False otherwise.
        """

        changed = self.has_state_changed()
        return changed

    @property
    def confidance_change(self) -> bool:
        """
        Tells whether the confidence level differs from the one recorded in the previous state.

        NOTE: the property name is misspelled ("confidance"); it is kept as-is because callers use it.

        Returns:
            bool: True if the current 'confidence_level' is not equal to the recorded one, False otherwise.
        """

        current_level = st.session_state["confidence_level"]
        recorded_level = st.session_state["previous_state"]["confidence_level"]
        return current_level != recorded_level

    def update_prev_state(self) -> None:
        """
        Copies the current value of every tracked setting into the 'previous_state' snapshot.

        The tracked settings are exactly the keys already present in 'previous_state'.

        Returns:
            None
        """

        snapshot = st.session_state['previous_state']
        # Only values are overwritten (no keys added or removed), so iterating the dict meanwhile is safe.
        for setting_name in snapshot:
            snapshot[setting_name] = st.session_state[setting_name]

    def load_model(self) -> None:
        """
        Loads the KBVQA model and records the settings it was loaded with.

        Steps, in order:
        - Calls `free_gpu_resources()`.
        - Stores the result of `prepare_kbvqa_model()` under the 'kbvqa' session key.
        - Copies the 'confidence_level' session value onto the model's `detection_confidence`.
        - Refreshes the previous-state snapshot used for change detection.
        - Sets 'model_loaded' to True and the button label to "Reload Model".
        - Calls `free_gpu_resources()` twice more.

        Any exception raised along the way is reported with `st.error` instead of being propagated.

        Returns:
            None
        """

        try:
            session = st.session_state
            free_gpu_resources()

            session['kbvqa'] = prepare_kbvqa_model()
            session['kbvqa'].detection_confidence = session.confidence_level

            # Remember the settings this model was loaded with.
            self.update_prev_state()
            session['model_loaded'] = True
            session['button_label'] = "Reload Model"

            free_gpu_resources()
            free_gpu_resources()
        except Exception as load_error:
            st.error(f"Error loading model: {load_error}")

    def force_reload_model(self) -> None:
        """
        Discards the current model and loads it again from scratch.

        Steps, in order:
        - Calls `delete_model()` and then `free_gpu_resources()`.
        - Stores the result of `prepare_kbvqa_model(force_reload=True)` under the 'kbvqa' session key.
        - Copies the 'confidence_level' session value onto the model's `detection_confidence`.
        - Refreshes the previous-state snapshot and sets 'model_loaded' to True.
        - Calls `free_gpu_resources()` again.

        Any exception is reported with `st.error`, after which `free_gpu_resources()` is still called.

        Returns:
            None
        """

        try:
            self.delete_model()
            free_gpu_resources()

            session = st.session_state
            session['kbvqa'] = prepare_kbvqa_model(force_reload=True)
            session['kbvqa'].detection_confidence = session.confidence_level

            # Remember the settings this model was loaded with.
            self.update_prev_state()
            session['model_loaded'] = True

            free_gpu_resources()
        except Exception as reload_error:
            st.error(f"Error reloading model: {reload_error}")
            free_gpu_resources()

    def delete_model(self) -> None:
        """
        Removes the current KBVQA model from the session state and calls `free_gpu_resources`.

        `free_gpu_resources()` is always called once up front. The 'kbvqa' session entry is deleted only when
        `is_model_loaded` is true; in that case `free_gpu_resources()` is called twice more afterwards, whether
        or not the deletion succeeded.

        Deletion is best-effort: an ordinary exception while deleting is ignored. Unlike a bare `except:`,
        KeyboardInterrupt and SystemExit are no longer swallowed.

        Returns:
            None
        """

        free_gpu_resources()

        if not self.is_model_loaded:
            return

        try:
            del st.session_state['kbvqa']
        except Exception:
            # Best-effort: a failed deletion must not prevent the cleanup below or the caller's reload.
            pass
        finally:
            free_gpu_resources()
            free_gpu_resources()

    def has_state_changed(self) -> bool:
        """
        Reports whether any tracked setting differs from its recorded previous value.

        Every key of 'previous_state' is checked except 'confidence_level', which has its own tracker.
        A key that is absent from the session state is not counted as a change.

        Returns:
            bool: True if any change is found, False otherwise.
        """

        session = st.session_state
        snapshot = session['previous_state']
        return any(
            name in session and session[name] != snapshot[name]
            for name in snapshot
            if name != 'confidence_level'  # confidence_level tracker is separate
        )

    def get_model(self) -> KBVQA:
        """
        Fetches whatever is stored under the 'kbvqa' session key.

        Returns:
            KBVQA: The stored model; None when the key is missing or holds None.
        """

        stored_model = st.session_state.get('kbvqa', None)
        return stored_model

    @property
    def is_model_loaded(self) -> bool:
        """
        Tells whether a usable KBVQA model for the currently selected method is in the session state.

        All of the following must hold:
        - the 'kbvqa' session key exists and is not None,
        - the model's `all_models_loaded` is truthy,
        - a previous 'method' was recorded and equals the currently selected 'method'.

        Returns:
            bool: True if the model is loaded, False otherwise. (If `all_models_loaded` is falsy, that falsy
            value itself is returned.)
        """

        session = st.session_state

        if 'kbvqa' not in session:
            return False
        if session['kbvqa'] is None:
            return False

        components_ready = session.kbvqa.all_models_loaded
        if not components_ready:
            return components_ready

        # The model only counts as loaded for the method it was loaded with.
        recorded_method = session['previous_state']['method']
        return recorded_method is not None and session['method'] == recorded_method

    def reload_detection_model(self) -> None:
        """
        Reloads only the detection model of the KBVQA model with updated settings.

        - Calls `free_gpu_resources()` before and after the reload attempt.
        - Does nothing else unless `is_model_loaded` is true.
        - Calls `prepare_kbvqa_model(only_reload_detection_model=True)`; its return value is not used here
          (presumably it updates the model held in the session state in place — confirm against
          `prepare_kbvqa_model`).
        - Copies the 'confidence_level' session value onto the model's `detection_confidence`.
        - Shows a success message in the first column.
        - Refreshes the previous-state snapshot so the new settings stop being reported as changed.
        - Sets the button label to "Reload Model".

        Any exception is reported with `st.error` instead of being propagated.

        Returns:
            None
        """

        try:
            free_gpu_resources()
            if self.is_model_loaded:
                prepare_kbvqa_model(only_reload_detection_model=True)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                self.col1.success("Model reloaded with updated settings and ready for inference.")
                # This must be a call: a bare `self.update_prev_state` only looks the method up and leaves
                # the snapshot stale.
                self.update_prev_state()
                st.session_state['button_label'] = "Reload Model"

            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key: str, image) -> None:
        """
        Registers a newly uploaded image in the `images_data` dictionary of the session state.

        An image whose key is already registered is left untouched. A new entry starts with:
            - `image`: The image passed in.
            - `caption`: Empty string.
            - `detected_objects_str`: Empty string.
            - `qa_history`: Empty list.
            - `analysis_done`: False.

        Args:
            image_key (str): Unique key for the image.
            image (obj): The uploaded image data.

        Returns:
            None
        """

        registry = st.session_state['images_data']
        if image_key in registry:
            return

        blank_record = dict(
            image=image,
            caption='',
            detected_objects_str='',
            qa_history=[],
            analysis_done=False,
        )
        registry[image_key] = blank_record

    def analyze_image(self, image) -> Tuple[str, str, object]:
        """
        Runs captioning and object detection on an image with the KBVQA model held in the session state.

        - Calls `free_gpu_resources()` twice before the analysis and once after it.
        - Works on a deep copy of the image, so the object passed in is not handed to the model.
        - Calls the model's `get_caption` and then `detect_objects` on that copy.

        Args:
            image (obj): The image data to analyze.

        Returns:
            tuple: (caption, detected objects string, image with bounding boxes) — note that `detect_objects`
            yields the image first, so the order is rearranged here.
        """

        free_gpu_resources()
        free_gpu_resources()

        working_copy = copy.deepcopy(image)
        model = st.session_state['kbvqa']
        generated_caption = model.get_caption(working_copy)
        boxed_image, objects_description = model.detect_objects(working_copy)

        free_gpu_resources()
        return generated_caption, objects_description, boxed_image

    def add_to_qa_history(self, image_key: str, question: str, answer: str, prompt_length: int) -> None:
        """
        Appends one (question, answer, prompt_length) tuple to the QA history of a registered image.

        Nothing happens when `image_key` is not present in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            question (str): The question asked about the image.
            answer (str): The answer generated by the KBVQA model.
            prompt_length (int): The length of the prompt used for generating the answer.

        Returns:
            None
        """

        registry = st.session_state['images_data']
        if image_key not in registry:
            return

        history_entry = (question, answer, prompt_length)
        registry[image_key]['qa_history'].append(history_entry)

    def get_images_data(self) -> Dict:
        """
        Gives access to the `images_data` dictionary kept in the session state.

        Returns:
            dict: The dictionary object itself (not a copy) that stores the per-image records.
        """

        registry = st.session_state['images_data']
        return registry

    def update_image_data(self, image_key: str, caption: str, detected_objects_str: str, analysis_done: bool) -> None:
        """
        Overwrites the analysis fields of a registered image in the `images_data` dictionary.

        Only `caption`, `detected_objects_str` and `analysis_done` are replaced; the other fields of the record
        are kept. Nothing happens when `image_key` is not present in `images_data`.

        Args:
            image_key (str): Unique key for the image.
            caption (str): The generated caption for the image.
            detected_objects_str (str): String representation of detected objects.
            analysis_done (bool): Flag indicating if analysis of the image is complete.

        Returns:
            None
        """

        registry = st.session_state['images_data']
        if image_key not in registry:
            return

        registry[image_key].update(
            caption=caption,
            detected_objects_str=detected_objects_str,
            analysis_done=analysis_done,
        )

    def resize_image(self, image_input, new_width: Optional[int] = None,
                     new_height: Optional[int] = None) -> Image.Image:
        """
        Resizes an image.

        - If only `new_width` is provided, the height is derived to maintain the aspect ratio.
        - If only `new_height` is provided, the width is derived to maintain the aspect ratio.
        - If both are provided, the image is resized to exactly those dimensions.

        A derived dimension is truncated to an integer and never goes below 1 pixel.

        Args:
            image_input (str | PIL.Image.Image): A file path to open, or an already loaded PIL image.
            new_width (int, optional): The target width of the image.
            new_height (int, optional): The target height of the image.

        Returns:
            PIL.Image.Image: The resized image (a new object; `image_input` is not modified).

        Raises:
            ValueError: If neither `new_width` nor `new_height` is provided, or if `image_input` is neither a
                file path nor a PIL Image object.
        """

        # Reject an unusable request up front, before a file is opened.
        if new_width is None and new_height is None:
            raise ValueError("At least one of new_width or new_height must be provided")

        if isinstance(image_input, str):
            # Open the image from a file path
            image = Image.open(image_input)
        elif isinstance(image_input, Image.Image):
            # `Image.resize` returns a new image, so no defensive copy of the input is needed.
            image = image_input
        else:
            raise ValueError("image_input must be a file path or a PIL Image object")

        original_width, original_height = image.size
        if new_height is None:
            # Derive the height from the requested width.
            ratio = new_width / original_width
            new_height = max(1, int(original_height * ratio))
        elif new_width is None:
            # Derive the width from the requested height.
            ratio = new_height / original_height
            new_width = max(1, int(original_width * ratio))

        return image.resize((new_width, new_height))

    def display_message(self, message: str, message_type: str) -> None:
        """
        Displays a message in the Streamlit interface based on the specified message type.

        Args:
            message (str): The message to display.
            message_type (str): The type of message ('warning', 'text', 'success', 'write', or 'error').
                Any other value shows the fixed error "Message type unknown" instead of `message`.

        Returns:
            None
        """

        if message_type == "warning":
            st.warning(message)
        elif message_type == "text":
            st.text(message)
        elif message_type == "success":
            st.success(message)
        elif message_type == "write":
            st.write(message)
        elif message_type == "error":
            # 'error' is a documented type: show the caller's message rather than the unknown-type fallback.
            st.error(message)
        else:
            st.error("Message type unknown")

    @property
    def display_model_settings(self) -> None:
        """
        Renders a two-column table ('Setting', 'Value') of selected session-state entries in the third column.

        NOTE: this is a property with a side effect — merely accessing it writes a heading and the table to
        `col3`. Values are shown as their `str()` representation.

        Returns:
            None: Whatever `col3.write` returns is passed through.
        """
        self.col3.write("##### Current Model Settings:")

        # Session keys that make up the "model settings" view.
        displayed_keys = ("confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state',
                          'settings_changed', 'loading_in_progress', 'model_loaded', 'time_taken_to_load_model',
                          'images_data')

        rows = []
        for key, value in st.session_state.items():
            if key in displayed_keys:
                rows.append({'Setting': key, 'Value': str(value)})

        settings_table = pd.DataFrame(rows).reset_index(drop=True)
        return self.col3.write(settings_table)

    def display_session_state(self, col) -> None:
        """
        Renders a two-column table ('Key', 'Value') of every session-state entry in the given column.

        Values are shown as their `str()` representation.

        Args:
            col (streamlit.columns.Column): The Streamlit column to display the session state in.

        Returns:
            None
        """

        col.write("Current Model:")

        rows = []
        for key, value in st.session_state.items():
            rows.append({'Key': key, 'Value': str(value)})

        state_table = pd.DataFrame(rows).reset_index(drop=True)
        col.write(state_table)