added object detection to the space UI
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ import scipy
|
|
6 |
from PIL import Image
|
7 |
import torch.nn as nn
|
8 |
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
|
9 |
-
from
|
10 |
|
11 |
def load_caption_model(blip2=False, instructblip=True):
|
12 |
|
@@ -65,3 +65,56 @@ if st.button("Get Answer"):
|
|
65 |
st.write(answer)
|
66 |
else:
|
67 |
st.write("Please upload an image and enter a question.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from PIL import Image
|
7 |
import torch.nn as nn
|
8 |
from transformers import Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration
|
9 |
+
from my_model.object_detection import ObjectDetector
|
10 |
|
11 |
def load_caption_model(blip2=False, instructblip=True):
|
12 |
|
|
|
65 |
st.write(answer)
|
66 |
else:
|
67 |
st.write("Please upload an image and enter a question.")
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
# Object Detection
|
75 |
+
|
76 |
+
# Object Detection UI in the sidebar
|
77 |
+
st.sidebar.title("Object Detection")
|
78 |
+
# Dropdown to select the model
|
79 |
+
detect_model = st.sidebar.selectbox("Choose a model for object detection:", ["detic", "yolov5"])
|
80 |
+
# Slider for threshold with default values based on the model
|
81 |
+
threshold = st.sidebar.slider("Select Detection Threshold", 0.1, 0.9, 0.2 if detect_model == "yolov5" else 0.4)
|
82 |
+
# Button to trigger object detection
|
83 |
+
detect_button = st.sidebar.button("Detect Objects")
|
84 |
+
|
85 |
+
|
86 |
+
def perform_object_detection(image, model_name, threshold):
|
87 |
+
"""
|
88 |
+
Perform object detection on the given image using the specified model and threshold.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
image (PIL.Image): The image on which to perform object detection.
|
92 |
+
model_name (str): The name of the object detection model to use.
|
93 |
+
threshold (float): The threshold for object detection.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
PIL.Image, str: The image with drawn bounding boxes and a string of detected objects.
|
97 |
+
"""
|
98 |
+
# Initialize the ObjectDetector
|
99 |
+
detector = ObjectDetector()
|
100 |
+
# Load the specified model
|
101 |
+
detector.load_model(model_name)
|
102 |
+
# Perform object detection
|
103 |
+
processed_image, detected_objects = detector.detect_objects(image, threshold)
|
104 |
+
return processed_image, detected_objects
|
105 |
+
|
106 |
+
# Check if the 'Detect Objects' button was clicked
|
107 |
+
if detect_button:
|
108 |
+
if image is not None:
|
109 |
+
# Open the uploaded image
|
110 |
+
image = Image.open(image)
|
111 |
+
# Display the original image
|
112 |
+
st.image(image, use_column_width=True, caption="Original Image")
|
113 |
+
# Perform object detection
|
114 |
+
processed_image, detected_objects = perform_object_detection(image, detect_model, threshold)
|
115 |
+
# Display the image with detected objects
|
116 |
+
st.image(processed_image, use_column_width=True, caption="Image with Detected Objects")
|
117 |
+
# Display the detected objects
|
118 |
+
st.write(detected_objects)
|
119 |
+
else:
|
120 |
+
st.write("Please upload an image for object detection.")
|