Update my_model/state_manager.py
Browse files- my_model/state_manager.py +11 -14
my_model/state_manager.py
CHANGED
@@ -41,22 +41,22 @@ class StateManager:
|
|
41 |
"""
|
42 |
Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
|
43 |
"""
|
44 |
-
|
45 |
-
self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method', disabled=self.is_widget_disabled)
|
46 |
-
detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model', disabled=self.is_widget_disabled)
|
47 |
default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
|
48 |
-
self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1, slider_key_name='confidence_level', col=self.col1
|
49 |
|
50 |
# Conditional display of model settings
|
51 |
|
52 |
|
53 |
-
show_model_settings = self.col3.checkbox("Show Model Settings", False)
|
54 |
if show_model_settings:
|
55 |
self.display_model_settings()
|
56 |
|
57 |
|
58 |
|
59 |
-
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None
|
60 |
"""
|
61 |
Creates a slider widget with the specified parameters, optionally placing it in a specific column.
|
62 |
|
@@ -71,9 +71,9 @@ class StateManager:
|
|
71 |
"""
|
72 |
|
73 |
if col is None:
|
74 |
-
return st.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=
|
75 |
else:
|
76 |
-
return col.slider(text, min_value, max_value, value, step, key=slider_key_name, disabled=
|
77 |
|
78 |
@property
|
79 |
def is_widget_disabled(self):
|
@@ -129,14 +129,13 @@ class StateManager:
|
|
129 |
"""
|
130 |
|
131 |
try:
|
132 |
-
|
133 |
free_gpu_resources()
|
134 |
st.session_state['kbvqa'] = prepare_kbvqa_model()
|
135 |
st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
|
136 |
# Update the previous state with current session state values
|
137 |
st.session_state['previous_state'] = {'method': st.session_state.method, 'detection_model': st.session_state.detection_model, 'confidence_level': st.session_state.confidence_level}
|
138 |
st.session_state['model_loaded'] = True
|
139 |
-
st.session_state['loading_in_progress'] = False
|
140 |
st.session_state['button_label'] = "Reload Model"
|
141 |
free_gpu_resources()
|
142 |
|
@@ -145,7 +144,6 @@ class StateManager:
|
|
145 |
|
146 |
def force_reload_model(self):
|
147 |
try:
|
148 |
-
st.session_state['loading_in_progress'] = True
|
149 |
self.delete_model()
|
150 |
free_gpu_resources()
|
151 |
st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
|
@@ -153,7 +151,6 @@ class StateManager:
|
|
153 |
# Update the previous state with current session state values
|
154 |
st.session_state['previous_state'] = {'method': st.session_state.method, 'detection_model': st.session_state.detection_model, 'confidence_level': st.session_state.confidence_level}
|
155 |
st.session_state['model_loaded'] = True
|
156 |
-
st.session_state['loading_in_progress'] = False
|
157 |
free_gpu_resources()
|
158 |
except Exception as e:
|
159 |
st.error(f"Error reloading model: {e}")
|
@@ -223,11 +220,11 @@ class StateManager:
|
|
223 |
try:
|
224 |
free_gpu_resources()
|
225 |
if self.is_model_loaded():
|
226 |
-
|
227 |
prepare_kbvqa_model(only_reload_detection_model=True)
|
228 |
st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
|
229 |
self.col1.success("Model reloaded with updated settings and ready for inference.")
|
230 |
-
|
231 |
free_gpu_resources()
|
232 |
except Exception as e:
|
233 |
st.error(f"Error reloading detection model: {e}")
|
|
|
41 |
"""
|
42 |
Sets up user interface widgets for selecting models, settings, and displaying model settings conditionally.
|
43 |
"""
|
44 |
+
|
45 |
+
self.col1.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method', on_click=self.disable_widgets, disabled=self.is_widget_disabled)
|
46 |
+
detection_model = self.col1.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model', on_click=self.disable_widgets, disabled=self.is_widget_disabled)
|
47 |
default_confidence = 0.2 if st.session_state.detection_model == "yolov5" else 0.4
|
48 |
+
self.set_slider_value(text="Select minimum detection confidence level", min_value=0.1, max_value=0.9, value=default_confidence, step=0.1, slider_key_name='confidence_level', col=self.col1)
|
49 |
|
50 |
# Conditional display of model settings
|
51 |
|
52 |
|
53 |
+
show_model_settings = self.col3.checkbox("Show Model Settings", False, on_click=self.disable_widgets, disabled=self.is_widget_disabled)
|
54 |
if show_model_settings:
|
55 |
self.display_model_settings()
|
56 |
|
57 |
|
58 |
|
59 |
+
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name, col=None):
|
60 |
"""
|
61 |
Creates a slider widget with the specified parameters, optionally placing it in a specific column.
|
62 |
|
|
|
71 |
"""
|
72 |
|
73 |
if col is None:
|
74 |
+
return st.slider(text, min_value, max_value, value, step, key=slider_key_name, on_click=self.disable_widgets, disabled=self.is_widget_disabledd)
|
75 |
else:
|
76 |
+
return col.slider(text, min_value, max_value, value, step, key=slider_key_name, on_click=self.disable_widgets, disabled=self.is_widget_disabled)
|
77 |
|
78 |
@property
|
79 |
def is_widget_disabled(self):
|
|
|
129 |
"""
|
130 |
|
131 |
try:
|
132 |
+
|
133 |
free_gpu_resources()
|
134 |
st.session_state['kbvqa'] = prepare_kbvqa_model()
|
135 |
st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
|
136 |
# Update the previous state with current session state values
|
137 |
st.session_state['previous_state'] = {'method': st.session_state.method, 'detection_model': st.session_state.detection_model, 'confidence_level': st.session_state.confidence_level}
|
138 |
st.session_state['model_loaded'] = True
|
|
|
139 |
st.session_state['button_label'] = "Reload Model"
|
140 |
free_gpu_resources()
|
141 |
|
|
|
144 |
|
145 |
def force_reload_model(self):
|
146 |
try:
|
|
|
147 |
self.delete_model()
|
148 |
free_gpu_resources()
|
149 |
st.session_state['kbvqa'] = prepare_kbvqa_model(force_reload=True)
|
|
|
151 |
# Update the previous state with current session state values
|
152 |
st.session_state['previous_state'] = {'method': st.session_state.method, 'detection_model': st.session_state.detection_model, 'confidence_level': st.session_state.confidence_level}
|
153 |
st.session_state['model_loaded'] = True
|
|
|
154 |
free_gpu_resources()
|
155 |
except Exception as e:
|
156 |
st.error(f"Error reloading model: {e}")
|
|
|
220 |
try:
|
221 |
free_gpu_resources()
|
222 |
if self.is_model_loaded():
|
223 |
+
|
224 |
prepare_kbvqa_model(only_reload_detection_model=True)
|
225 |
st.session_state['kbvqa'].detection_confidence = st.session_state.confidence_level
|
226 |
self.col1.success("Model reloaded with updated settings and ready for inference.")
|
227 |
+
|
228 |
free_gpu_resources()
|
229 |
except Exception as e:
|
230 |
st.error(f"Error reloading detection model: {e}")
|