m7mdal7aj commited on
Commit
9c23df5
1 Parent(s): 141a983

Update my_model/object_detection.py

Browse files
Files changed (1) hide show
  1. my_model/object_detection.py +5 -4
my_model/object_detection.py CHANGED
@@ -30,6 +30,7 @@ class ObjectDetector:
30
  self.model = None
31
  self.processor = None
32
  self.model_name = None
 
33
 
34
 
35
  def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
@@ -64,8 +65,8 @@ class ObjectDetector:
64
 
65
  try:
66
  model_path = get_model_path('deformable-detr-detic')
67
- self.processor = AutoImageProcessor.from_pretrained(model_path)
68
- self.model = AutoModelForObjectDetection.from_pretrained(model_path)
69
  except Exception as e:
70
  print(f"Error loading Detic model: {e}")
71
  raise
@@ -83,9 +84,9 @@ class ObjectDetector:
83
  try:
84
  model_path = get_model_path ('yolov5')
85
  if model_path and os.path.exists(model_path):
86
- self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
87
  else:
88
- self.model = torch.hub.load('ultralytics/yolov5', model_version, pretrained=pretrained)
89
  except Exception as e:
90
  print(f"Error loading YOLOv5 model: {e}")
91
  raise
 
30
  self.model = None
31
  self.processor = None
32
  self.model_name = None
33
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
34
 
35
 
36
  def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
 
65
 
66
  try:
67
  model_path = get_model_path('deformable-detr-detic')
68
+ self.processor = AutoImageProcessor.from_pretrained(model_path, device_map = self.device)
69
+ self.model = AutoModelForObjectDetection.from_pretrained(model_path, device_map = self.device)
70
  except Exception as e:
71
  print(f"Error loading Detic model: {e}")
72
  raise
 
84
  try:
85
  model_path = get_model_path ('yolov5')
86
  if model_path and os.path.exists(model_path):
87
+ self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local', device_map = self.device)
88
  else:
89
+ self.model = torch.hub.load('ultralytics/yolov5', model_version, pretrained=pretrained, device_map = self.device)
90
  except Exception as e:
91
  print(f"Error loading YOLOv5 model: {e}")
92
  raise