m7mdal7aj committed on
Commit
3a7f16d
1 Parent(s): e3c0d2e

Update my_model/detector/object_detection.py

Browse files
Files changed (1) hide show
  1. my_model/detector/object_detection.py +46 -43
my_model/detector/object_detection.py CHANGED
@@ -1,4 +1,4 @@
1
-
2
  import streamlit as st
3
  from transformers import AutoImageProcessor, AutoModelForObjectDetection
4
  import torch
@@ -11,18 +11,18 @@ from my_model.utilities.gen_utilities import get_image_path, get_model_path ,sho
11
 
12
  class ObjectDetector:
13
  """
14
- A class for detecting objects in images using models like Detic and YOLOv5.
15
-
16
- This class supports loading and using different object detection models to identify objects
17
- in images and draw bounding boxes around them.
18
-
19
- Attributes:
20
- model (torch.nn.Module): The loaded object detection model.
21
- processor (transformers.AutoImageProcessor): Processor for the Detic model.
22
- model_name (str): Name of the model used for detection.
23
- """
24
 
25
- def __init__(self):
26
  """
27
  Initializes the ObjectDetector class with default values.
28
  """
@@ -33,17 +33,17 @@ class ObjectDetector:
33
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
 
35
 
36
- def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
37
  """
38
- Load the specified object detection model.
39
 
40
- Args:
41
- model_name (str): Name of the model to load. Options are 'detic' and 'yolov5'.
42
- pretrained (bool): Boolean indicating if a pretrained model should be used.
43
- model_version (str): Version of the YOLOv5 model, applicable only when using YOLOv5.
44
 
45
- Raises:
46
- ValueError: If an unsupported model name is provided.
47
  """
48
 
49
  self.model_name = model_name
@@ -55,12 +55,15 @@ class ObjectDetector:
55
  raise ValueError(f"Unsupported model name: {model_name}")
56
 
57
 
58
- def _load_detic_model(self, pretrained):
59
  """
60
  Load the Detic model.
61
 
62
  Args:
63
  pretrained (bool): If True, load a pretrained model.
 
 
 
64
  """
65
 
66
  try:
@@ -72,13 +75,15 @@ class ObjectDetector:
72
  raise
73
 
74
 
75
- def _load_yolov5_model(self, pretrained, model_version):
76
  """
77
  Load the YOLOv5 model.
78
 
79
  Args:
80
  pretrained (bool): If True, load a pretrained model.
81
  model_version (str): Version of the YOLOv5 model.
 
 
82
  """
83
 
84
  try:
@@ -92,13 +97,16 @@ class ObjectDetector:
92
  raise
93
 
94
 
95
- def process_image(self, image_input):
96
  """
97
  Process the image from the given path or file-like object.
 
98
  Args:
99
- image_input (str or file-like object): Path to the image file or a file-like object.
 
100
  Returns:
101
  Image.Image: Processed image in RGB format.
 
102
  Raises:
103
  Exception: If an error occurs during image processing.
104
  """
@@ -119,16 +127,17 @@ class ObjectDetector:
119
  raise
120
 
121
 
122
- def detect_objects(self, image, threshold=0.4):
123
  """
124
  Detect objects in the given image using the loaded model.
125
 
126
  Args:
127
  image (Image.Image): Image in which to detect objects.
128
- threshold (float): Model detection confidence.
129
 
130
  Returns:
131
- tuple: A tuple containing a string representation and a list of detected objects.
 
132
 
133
  Raises:
134
  ValueError: If the model is not loaded or the model name is unsupported.
@@ -142,7 +151,7 @@ class ObjectDetector:
142
  raise ValueError("Model not loaded or unsupported model name")
143
 
144
 
145
- def _detect_with_detic(self, image, threshold):
146
  """
147
  Detect objects using the Detic model.
148
 
@@ -151,8 +160,8 @@ class ObjectDetector:
151
  threshold (float): The confidence threshold for detections.
152
 
153
  Returns:
154
- tuple: A tuple containing a string representation and a list of detected objects.
155
- Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
156
  """
157
 
158
  inputs = self.processor(images=image, return_tensors="pt")
@@ -171,7 +180,7 @@ class ObjectDetector:
171
  return detected_objects_str, detected_objects_list
172
 
173
 
174
- def _detect_with_yolov5(self, image, threshold):
175
  """
176
  Detect objects using the YOLOv5 model.
177
 
@@ -180,8 +189,8 @@ class ObjectDetector:
180
  threshold (float): The confidence threshold for detections.
181
 
182
  Returns:
183
- tuple: A tuple containing a string representation and a list of detected objects.
184
- Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
185
  """
186
 
187
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
@@ -198,13 +207,13 @@ class ObjectDetector:
198
  return detected_objects_str, detected_objects_list
199
 
200
 
201
- def draw_boxes(self, image, detected_objects, show_confidence=True):
202
  """
203
  Draw bounding boxes around detected objects in the image.
204
 
205
  Args:
206
  image (Image.Image): Image on which to draw.
207
- detected_objects (list): List of detected objects.
208
  show_confidence (bool): Whether to show confidence scores.
209
 
210
  Returns:
@@ -232,7 +241,7 @@ class ObjectDetector:
232
  return image
233
 
234
 
235
- def detect_and_draw_objects(image_path, model_type='yolov5', threshold=0.2, show_confidence=True):
236
  """
237
  Detects objects in an image, draws bounding boxes around them, and returns the processed image and a string description.
238
 
@@ -243,7 +252,7 @@ def detect_and_draw_objects(image_path, model_type='yolov5', threshold=0.2, show
243
  show_confidence (bool): Whether to show confidence scores on the output image.
244
 
245
  Returns:
246
- tuple: A tuple containing the processed Image.Image and a string of detected objects.
247
  """
248
 
249
  detector = ObjectDetector()
@@ -252,9 +261,3 @@ def detect_and_draw_objects(image_path, model_type='yolov5', threshold=0.2, show
252
  detected_objects_string, detected_objects_list = detector.detect_objects(image, threshold=threshold)
253
  image_with_boxes = detector.draw_boxes(image, detected_objects_list, show_confidence=show_confidence)
254
  return image_with_boxes, detected_objects_string
255
-
256
-
257
-
258
- if __name__ == "__main__":
259
- pass
260
-
 
1
+ from typing import Union, Optional, List, Tuple
2
  import streamlit as st
3
  from transformers import AutoImageProcessor, AutoModelForObjectDetection
4
  import torch
 
11
 
12
  class ObjectDetector:
13
  """
14
+ A class for detecting objects in images using models like Detic and YOLOv5.
15
+ This class supports loading and using different object detection models to identify objects
16
+ in images and draw bounding boxes around them.
17
+
18
+ Attributes:
19
+ model (torch.nn.Module or None): The loaded object detection model.
20
+ processor (transformers.AutoImageProcessor or None): Processor for the Detic model.
21
+ model_name (str or None): Name of the model used for detection.
22
+ device (str): Device to use for computation ('cuda' if available, otherwise 'cpu').
23
+ """
24
 
25
+ def __init__(self) -> None:
26
  """
27
  Initializes the ObjectDetector class with default values.
28
  """
 
33
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
 
35
 
36
+ def load_model(self, model_name: str = 'detic', pretrained: bool = True, model_version: str = 'yolov5s') -> None:
37
  """
38
+ Load the specified object detection model.
39
 
40
+ Args:
41
+ model_name (str): Name of the model to load. Options are 'detic' and 'yolov5'.
42
+ pretrained (bool): Boolean indicating if a pretrained model should be used.
43
+ model_version (str): Version of the YOLOv5 model, applicable only when using YOLOv5.
44
 
45
+ Raises:
46
+ ValueError: If an unsupported model name is provided.
47
  """
48
 
49
  self.model_name = model_name
 
55
  raise ValueError(f"Unsupported model name: {model_name}")
56
 
57
 
58
+ def _load_detic_model(self, pretrained: bool) -> None:
59
  """
60
  Load the Detic model.
61
 
62
  Args:
63
  pretrained (bool): If True, load a pretrained model.
64
+
65
+ Raises:
66
+ Exception: If an error occurs during model loading.
67
  """
68
 
69
  try:
 
75
  raise
76
 
77
 
78
+ def _load_yolov5_model(self, pretrained: bool, model_version: str) -> None:
79
  """
80
  Load the YOLOv5 model.
81
 
82
  Args:
83
  pretrained (bool): If True, load a pretrained model.
84
  model_version (str): Version of the YOLOv5 model.
85
+ Raises:
86
+ Exception: If an error occurs during model loading.
87
  """
88
 
89
  try:
 
97
  raise
98
 
99
 
100
+ def process_image(self, image_input: Union[str, io.IOBase, Image.Image]) -> Image.Image:
101
  """
102
  Process the image from the given path or file-like object.
103
+
104
  Args:
105
+ image_input (Union[str, io.IOBase, Image.Image]): Path to the image file, a file-like object, or a PIL Image.
106
+
107
  Returns:
108
  Image.Image: Processed image in RGB format.
109
+
110
  Raises:
111
  Exception: If an error occurs during image processing.
112
  """
 
127
  raise
128
 
129
 
130
+ def detect_objects(self, image: Image.Image, threshold: float = 0.4) -> Tuple[str, List[Tuple[str, List[float], float]]]:
131
  """
132
  Detect objects in the given image using the loaded model.
133
 
134
  Args:
135
  image (Image.Image): Image in which to detect objects.
136
+ threshold (float): Model detection confidence threshold.
137
 
138
  Returns:
139
+ Tuple[str, List[Tuple[str, List[float], float]]]: A tuple containing a string representation and a list of detected objects.
140
+ Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
141
 
142
  Raises:
143
  ValueError: If the model is not loaded or the model name is unsupported.
 
151
  raise ValueError("Model not loaded or unsupported model name")
152
 
153
 
154
+ def _detect_with_detic(self, image: Image.Image, threshold: float) -> Tuple[str, List[Tuple[str, List[float], float]]]:
155
  """
156
  Detect objects using the Detic model.
157
 
 
160
  threshold (float): The confidence threshold for detections.
161
 
162
  Returns:
163
+ Tuple[str, List[Tuple[str, List[float], float]]]: A tuple containing a string representation and a list of detected objects.
164
+ Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
165
  """
166
 
167
  inputs = self.processor(images=image, return_tensors="pt")
 
180
  return detected_objects_str, detected_objects_list
181
 
182
 
183
+ def _detect_with_yolov5(self, image: Image.Image, threshold: float) -> Tuple[str, List[Tuple[str, List[float], float]]]:
184
  """
185
  Detect objects using the YOLOv5 model.
186
 
 
189
  threshold (float): The confidence threshold for detections.
190
 
191
  Returns:
192
+ Tuple[str, List[Tuple[str, List[float], float]]]: A tuple containing a string representation and a list of detected objects.
193
+ Each object in the list is represented as a tuple (label_name, box_rounded, certainty).
194
  """
195
 
196
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
207
  return detected_objects_str, detected_objects_list
208
 
209
 
210
+ def draw_boxes(self, image: Image.Image, detected_objects: List[Tuple[str, List[float], float]], show_confidence: bool = True) -> Image.Image:
211
  """
212
  Draw bounding boxes around detected objects in the image.
213
 
214
  Args:
215
  image (Image.Image): Image on which to draw.
216
+ detected_objects (List[Tuple[str, List[float], float]]): List of detected objects.
217
  show_confidence (bool): Whether to show confidence scores.
218
 
219
  Returns:
 
241
  return image
242
 
243
 
244
+ def detect_and_draw_objects(image_path: str, model_type: str = 'yolov5', threshold: float = 0.2, show_confidence: bool = True) -> Tuple[Image.Image, str]:
245
  """
246
  Detects objects in an image, draws bounding boxes around them, and returns the processed image and a string description.
247
 
 
252
  show_confidence (bool): Whether to show confidence scores on the output image.
253
 
254
  Returns:
255
+ Tuple[Image.Image, str]: A tuple containing the processed Image.Image and a string of detected objects.
256
  """
257
 
258
  detector = ObjectDetector()
 
261
  detected_objects_string, detected_objects_list = detector.detect_objects(image, threshold=threshold)
262
  image_with_boxes = detector.draw_boxes(image, detected_objects_list, show_confidence=show_confidence)
263
  return image_with_boxes, detected_objects_string