File size: 5,908 Bytes
18d1852
e46d486
18d1852
753c201
1d51bf5
18d1852
 
 
 
 
 
 
 
 
74d450c
77e6994
4170a5f
74d450c
09ff02d
2152f1f
2957e90
1b71503
 
22d0357
e4eee9a
f72214b
18d1852
d80fd56
 
 
 
 
18d1852
 
 
2957e90
18d1852
 
d182243
d0a09f4
d6a4897
5dd58f8
d182243
18d1852
 
3689b26
18d1852
 
 
 
d0e9fe6
753c201
 
 
d0e9fe6
cc825df
d9364fd
d0e9fe6
f72214b
 
9ccbf8e
753c201
 
f72214b
 
 
 
 
 
 
 
753c201
18d1852
 
 
 
 
 
 
cd3678b
18d1852
 
 
1fc0405
18d1852
105e89e
18d1852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d182243
18d1852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import pandas as pd
import copy
import streamlit as st
from my_model.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model


class StateManager:
    """Read and write the Streamlit session state used by the KB-VQA app.

    The class has no instance attributes of its own: every value it manages
    is stored in ``st.session_state``.
    """

    # Session-state keys listed in the "Current Model Settings" table.
    _DISPLAYED_SETTINGS = ("confidence_level", 'detection_model', 'method', 'kbvqa')

    @staticmethod
    def _state_differs(current, previous):
        """Return True if any key recorded in *previous* has a different value in *current*.

        A key present in *previous* but missing from *current* counts as a
        change. An empty *previous* means there is nothing to compare, so the
        result is False.
        """
        missing = object()  # sentinel: distinguishes "absent" from a stored None
        return any(current.get(key, missing) != value for key, value in previous.items())

    def initialize_state(self):
        """Create the session-state entries the app relies on, if not present yet."""
        defaults = {
            'images_data': {},
            'kbvqa': None,
            # The key name is misspelled ("lablel") but is kept because other
            # code may read it under this exact name.
            'button_lablel': "Load Model",
            'previous_state': {},
        }
        for key, value in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = value

    def set_up_widgets(self):
        """Render the method, detection-model and confidence-level widgets.

        The widgets are bound to the session-state keys ``method``,
        ``detection_model`` and ``confidence_level``. The slider's default is
        0.2 when yolov5 is selected and 0.4 otherwise.
        """
        st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
        st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
        default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
        self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9,
                              value=default_confidence, step=0.1, slider_key_name='confidence_level')

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
        """Render an ``st.slider`` bound to ``slider_key_name`` and return its current value."""
        return st.slider(text, min_value, max_value, value, step, key=slider_key_name)

    def check_settings_changed(self, current_selected_method, current_detection_model, current_confidence_level):
        """Return True if the given settings differ from ``st.session_state['model_settings']``.

        If no ``model_settings`` have been recorded yet, the settings are
        reported as changed instead of raising ``KeyError``.
        """
        saved = st.session_state.get('model_settings') or {}
        current = {
            'detection_model': current_detection_model,
            'confidence_level': current_confidence_level,
            'selected_method': current_selected_method,
        }
        # Each current value is checked against the saved one; a missing saved key counts as a change.
        return self._state_differs(saved, current)

    def display_model_settings(self):
        """Show the model-related session-state entries as a styled table."""
        st.write("#### Current Model Settings:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()
                if key in self._DISPLAYED_SETTINGS]
        df = pd.DataFrame(data)
        styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
        st.table(styled_df)

    def display_session_state(self):
        """Show every session-state entry (stringified) as a plain table."""
        st.write("Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def load_model(self):
        """Load the KBVQA model with the current widget settings.

        On success:

        - the model is stored in ``st.session_state['kbvqa']``;
        - the settings used are snapshotted into ``previous_state`` for
          ``has_state_changed``;
        - the button label becomes "Reload Model".

        Any exception is shown in the UI with ``st.error`` rather than raised.
        """
        try:
            free_gpu_resources()
            st.text("Loading the model, this should take no more than a few minutes, please wait...")
            st.session_state['kbvqa'] = prepare_kbvqa_model()
            st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            # Snapshot the settings used for this load so later edits can be detected.
            st.session_state['previous_state'] = {
                'method': st.session_state.method,
                'detection_model': st.session_state.detection_model,
                'confidence_level': st.session_state.confidence_level,
            }
            st.session_state['button_lablel'] = "Reload Model"
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def has_state_changed(self):
        """Return True if any setting snapshotted by ``load_model`` has changed since.

        Every key stored in ``previous_state`` is compared, not just the first one.
        """
        previous = st.session_state.get('previous_state') or {}
        return self._state_differs(st.session_state, previous)

    def get_model(self):
        """Return the KBVQA model from the session state, or None if absent."""
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return True if a KBVQA model is stored in the session state."""
        return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None

    def reload_detection_model(self):
        """Reload only the object-detection part of an already loaded KBVQA model.

        When no model is loaded this only frees GPU resources. Errors are shown
        in the UI with ``st.error``.
        """
        try:
            free_gpu_resources()
            if self.is_model_loaded():
                # NOTE(review): the return value is discarded; presumably
                # prepare_kbvqa_model updates the session model in place — confirm.
                prepare_kbvqa_model(only_reload_detection_model=True)
                # Read the slider value from session state, as load_model does.
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
            free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """Register *image* under *image_key* with empty analysis results.

        Existing entries are left untouched. ``kbvqa`` is accepted for
        interface compatibility but is not used.
        """
        images_data = st.session_state['images_data']
        if image_key not in images_data:
            images_data[image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image, kbvqa):
        """Caption *image* and detect its objects with *kbvqa*.

        Works on a deep copy so the caller's image object is never handed to
        the model methods directly.

        Returns:
            tuple: ``(caption, detected_objects_str, image_with_boxes)``.
        """
        img = copy.deepcopy(image)
        st.text("Analyzing the image .. ")
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append ``(question, answer)`` to the QA history of a registered image; no-op otherwise."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the per-image data dict stored in the session state."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Store analysis results for a registered image; no-op for unknown keys."""
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })