File size: 4,678 Bytes
18d1852
e46d486
18d1852
753c201
1d51bf5
18d1852
 
 
bcb92d8
 
18d1852
 
 
 
7d71f1b
bcb92d8
c03044f
bcb92d8
18d1852
 
c03044f
bcb92d8
2957e90
f51ceea
18d1852
 
d80fd56
 
 
 
 
18d1852
 
 
2957e90
18d1852
 
 
d0a09f4
 
18d1852
 
d0a09f4
18d1852
 
 
 
d0e9fe6
753c201
 
 
d0e9fe6
d9364fd
 
d0e9fe6
 
753c201
 
 
 
18d1852
 
 
 
 
 
 
 
 
 
 
 
 
105e89e
18d1852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import pandas as pd
import copy
import streamlit as st
from my_model.gen_utilities import free_gpu_resources
from my_model.KBVQA import KBVQA, prepare_kbvqa_model


class StateManager:
    """Keep the KBVQA app's model, settings and per-image data in ``st.session_state``.

    All state lives in Streamlit's session state; this class only provides
    accessors and helpers around it and holds no state of its own.
    """

    def __init__(self):
        """Create the manager and ensure the expected session-state keys exist."""
        self.initialize_state()

    def initialize_state(self):
        """Seed any missing session-state keys with their default values.

        Keys that already exist are left untouched, so this is safe to call on
        every script run.
        """
        # Built per call so the mutable {} default is never shared between sessions.
        defaults = {
            'images_data': {},
            'method': None,
            'detection_model': None,
            'kbvqa': None,
            'confidence_level': None,
        }
        for key, default in defaults.items():
            if key not in st.session_state:
                st.session_state[key] = default

    def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
        """Render a Streamlit slider and return its current value.

        Args:
            text: Label shown next to the slider.
            min_value: Smallest selectable value.
            max_value: Largest selectable value.
            value: Initial value.
            step: Increment between selectable values.
            slider_key_name: Widget key, used by Streamlit to identify the slider.

        Returns:
            Whatever ``st.slider`` returns for the current interaction.
        """
        return st.slider(text, min_value, max_value, value, step, key=slider_key_name)

    def check_settings_changed(self, current_selected_method, current_detection_model, current_confidence_level):
        """Return True if the given selections differ from the recorded model settings.

        The recorded settings are read from ``st.session_state['model_settings']``,
        expected to be a dict with 'selected_method', 'detection_model' and
        'confidence_level' keys. Nothing in this class writes that key, and
        ``initialize_state`` does not seed it. When no settings have been
        recorded yet, the selections are reported as changed instead of raising
        ``KeyError``.

        Args:
            current_selected_method: Method currently selected in the UI.
            current_detection_model: Detection model currently selected in the UI.
            current_confidence_level: Confidence threshold currently selected in the UI.

        Returns:
            bool: True if any selection differs from the recorded settings, or
            if there are no recorded settings.
        """
        settings = st.session_state.get('model_settings')
        if not settings:
            # Nothing recorded yet -> whatever is selected now is "new".
            return True
        return (settings.get('detection_model') != current_detection_model or
                settings.get('confidence_level') != current_confidence_level or
                settings.get('selected_method') != current_selected_method)

    def display_model_settings(self):
        """Show a table of the model-related session-state entries.

        Only 'confidence_level', 'detection_model', 'method' and 'kbvqa' are
        shown, in session-state iteration order, with values rendered via str().
        """
        st.write("### Current Model Settings:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa']]
        st.table(pd.DataFrame(data))

    def display_session_state(self):
        """Show a table of every session-state entry, with values rendered via str()."""
        st.write("### Current Model:")
        data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
        df = pd.DataFrame(data)
        st.table(df)

    def load_model(self):
        """Load the KBVQA model with specified settings.

        Builds the model for ``st.session_state.detection_model``, stores it in
        ``st.session_state['kbvqa']`` and applies
        ``st.session_state.confidence_level`` as its detection confidence.
        GPU resources are freed before loading and always freed afterwards,
        even if loading fails. Any exception is reported with ``st.error``
        rather than propagated.
        """
        try:
            free_gpu_resources()
            try:
                st.text("Loading the model, this should take no more than a few minutes, please wait...")
                st.session_state['kbvqa'] = prepare_kbvqa_model(st.session_state.detection_model)
                st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
                st.text("Model is ready for inference.")
            finally:
                # Release whatever a failed (partial) load may have allocated.
                free_gpu_resources()
        except Exception as e:
            st.error(f"Error loading model: {e}")

    def get_model(self):
        """Retrieve the KBVQA model from the session state.

        Returns:
            The stored KBVQA model, or None if none has been loaded.
        """
        return st.session_state.get('kbvqa', None)

    def is_model_loaded(self):
        """Return True if a non-None KBVQA model is stored in the session state."""
        return st.session_state.get('kbvqa') is not None

    def reload_detection_model(self, detection_model, confidence_level):
        """Reload only the object-detection part of an already-loaded model.

        Does nothing beyond freeing GPU resources when no model is loaded.
        Exceptions are reported with ``st.error`` rather than propagated, and
        GPU resources are freed even when the reload fails.

        Args:
            detection_model: Identifier of the detection model to load.
            confidence_level: Detection confidence to set on the loaded model.
        """
        try:
            free_gpu_resources()
            try:
                if self.is_model_loaded():
                    # NOTE(review): the return value is discarded; this presumes
                    # prepare_kbvqa_model(..., only_reload_detection_model=True)
                    # updates the session model in place -- confirm in my_model.KBVQA.
                    prepare_kbvqa_model(detection_model, only_reload_detection_model=True)
                    st.session_state['kbvqa'].detection_confidence = confidence_level
            finally:
                free_gpu_resources()
        except Exception as e:
            st.error(f"Error reloading detection model: {e}")

    def process_new_image(self, image_key, image, kbvqa):
        """Register an image under ``image_key`` with empty analysis data.

        An already-registered key is left unchanged.

        Args:
            image_key: Key identifying the image in ``images_data``.
            image: The image object to store.
            kbvqa: Unused; kept for interface compatibility with existing callers.
        """
        if image_key not in st.session_state['images_data']:
            st.session_state['images_data'][image_key] = {
                'image': image,
                'caption': '',
                'detected_objects_str': '',
                'qa_history': [],
                'analysis_done': False
            }

    def analyze_image(self, image, kbvqa):
        """Caption the image and detect objects in it.

        Both steps run on a deep copy, so ``image`` itself is never modified.

        Args:
            image: The image to analyze.
            kbvqa: Model providing ``get_caption`` and ``detect_objects``.

        Returns:
            tuple: ``(caption, detected_objects_str, image_with_boxes)``.
        """
        img = copy.deepcopy(image)
        caption = kbvqa.get_caption(img)
        image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
        return caption, detected_objects_str, image_with_boxes

    def add_to_qa_history(self, image_key, question, answer):
        """Append a ``(question, answer)`` pair to a registered image's Q&A history.

        Unknown image keys are ignored.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key]['qa_history'].append((question, answer))

    def get_images_data(self):
        """Return the ``images_data`` dict stored in the session state (not a copy)."""
        return st.session_state['images_data']

    def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
        """Store analysis results for a registered image.

        Unknown image keys are ignored.

        Args:
            image_key: Key of the image in ``images_data``.
            caption: Generated caption.
            detected_objects_str: Text description of the detected objects.
            analysis_done: Whether the analysis has completed.
        """
        if image_key in st.session_state['images_data']:
            st.session_state['images_data'][image_key].update({
                'caption': caption,
                'detected_objects_str': detected_objects_str,
                'analysis_done': analysis_done
            })