Spaces:
Sleeping
Sleeping
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path | |
from ultralytics.engine.model import Model | |
from ultralytics.models import yolo | |
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel | |
from ultralytics.utils import yaml_load, ROOT | |
class YOLO(Model):
    """
    YOLO (You Only Look Once) object detection model.

    Passing a YOLO-World or YOLOv10 weight/config file makes the constructed object take on the
    class and state of the matching specialised model (see ``__init__``).
    """

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, switching to a specialised model class based on the filename.

        Dispatch rules, checked in order:
            - stem contains '-world' and suffix is .pt/.yaml/.yml -> ``YOLOWorld``
            - stem contains 'yolov10'                              -> ``YOLOv10``
            - anything else                                         -> default ``Model`` initialization

        Args:
            model (str | Path): Model weights or config path. Defaults to 'yolov8n.pt'.
            task (str | None): Task forwarded to ``Model.__init__``; only used on the default path.
            verbose (bool): Verbosity forwarded to ``Model.__init__``; only used on the default path.
        """
        path = Path(model)
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            new_instance = YOLOWorld(path)
        elif "yolov10" in path.stem:
            # Imported lazily, presumably to avoid a circular import at module load time — TODO confirm
            from ultralytics import YOLOv10

            new_instance = YOLOv10(path)
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)
            return

        # Adopt the specialised instance's class and attribute dict so `self` behaves as that model.
        self.__class__ = type(new_instance)
        self.__dict__ = new_instance.__dict__

    @property
    def task_map(self):
        """
        Map each supported task to its model, trainer, validator, and predictor classes.

        Exposed as a property because the base ``Model`` indexes ``self.task_map[task][key]``;
        a plain method would not be subscriptable.

        Returns:
            (dict): ``{task: {"model", "trainer", "validator", "predictor"}}`` for the
            'classify', 'detect', 'segment', 'pose', and 'obb' tasks.
        """
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }
class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt") -> None:
        """
        Initialize the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.

        Args:
            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
        """
        super().__init__(model=model, task="detect")

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """
        Map the 'detect' task to its model, validator, and predictor classes.

        No trainer entry is provided. Exposed as a property because the base ``Model``
        indexes ``self.task_map[task][key]``.

        Returns:
            (dict): ``{"detect": {"model", "validator", "predictor"}}``.
        """
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            }
        }

    def set_classes(self, classes):
        """
        Set the class vocabulary of the model.

        The full ``classes`` sequence is passed to ``self.model.set_classes``, including an
        optional background entry " ". The background entry is then excluded from the names
        stored on ``self.model.names`` and on the existing predictor's model, if any.
        The caller's ``classes`` object is not modified.

        Args:
            classes (List[str]): A list of categories, e.g. ["person"].
        """
        self.model.set_classes(classes)

        # Remove background if it's given. Build a new list rather than calling
        # `classes.remove(...)`, which mutated the caller's list and failed on tuples.
        # NOTE(review): names are presumably index-aligned with model outputs, so a
        # background entry is expected to be last — TODO confirm.
        background = " "
        names = [name for name in classes if name != background]
        self.model.names = names

        # Update an already-created predictor too, otherwise it keeps the old names.
        if self.predictor is not None:
            self.predictor.model.names = names