xiaoming32236046's picture
Upload 325 files
53ad959 verified
raw
history blame
4.16 kB
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
from ultralytics.utils import yaml_load, ROOT
class YOLO(Model):
    """YOLO (You Only Look Once) model wrapper covering classify, detect, segment, pose and OBB tasks."""

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, morphing into a specialised wrapper when the file name calls for it.

        If the file stem contains '-world' and the suffix is .pt/.yaml/.yml, the instance becomes a YOLOWorld;
        if the stem contains 'yolov10', it becomes a YOLOv10. In both of those cases only the path is passed to
        the specialised constructor (`task` and `verbose` are not forwarded). Otherwise the regular Model
        initialization runs with all three arguments.

        Args:
            model (str | Path): Model weights or config file name/path. Defaults to 'yolov8n.pt'.
            task (str | None): Task passed to the base Model initializer on the default path.
            verbose (bool): Verbosity flag passed to the base Model initializer on the default path.
        """
        model_path = Path(model)
        stem = model_path.stem

        delegate = None
        if "-world" in stem and model_path.suffix in {".pt", ".yaml", ".yml"}:
            # YOLOWorld PyTorch weights or YAML config
            delegate = YOLOWorld(model_path)
        elif "yolov10" in stem:
            from ultralytics import YOLOv10

            delegate = YOLOv10(model_path)

        if delegate is None:
            # No specialised wrapper matched: plain YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)
            return

        # Adopt the specialised instance's type and state in place
        self.__class__ = type(delegate)
        self.__dict__ = delegate.__dict__

    @property
    def task_map(self):
        """Return the per-task lookup of model, trainer, validator and predictor classes."""
        classify = yolo.classify
        detect = yolo.detect
        segment = yolo.segment
        pose = yolo.pose
        obb = yolo.obb

        classify_entry = {
            "model": ClassificationModel,
            "trainer": classify.ClassificationTrainer,
            "validator": classify.ClassificationValidator,
            "predictor": classify.ClassificationPredictor,
        }
        detect_entry = {
            "model": DetectionModel,
            "trainer": detect.DetectionTrainer,
            "validator": detect.DetectionValidator,
            "predictor": detect.DetectionPredictor,
        }
        segment_entry = {
            "model": SegmentationModel,
            "trainer": segment.SegmentationTrainer,
            "validator": segment.SegmentationValidator,
            "predictor": segment.SegmentationPredictor,
        }
        pose_entry = {
            "model": PoseModel,
            "trainer": pose.PoseTrainer,
            "validator": pose.PoseValidator,
            "predictor": pose.PosePredictor,
        }
        obb_entry = {
            "model": OBBModel,
            "trainer": obb.OBBTrainer,
            "validator": obb.OBBValidator,
            "predictor": obb.OBBPredictor,
        }
        return {
            "classify": classify_entry,
            "detect": detect_entry,
            "segment": segment_entry,
            "pose": pose_entry,
            "obb": obb_entry,
        }
class YOLOWorld(Model):
    """YOLO-World object detection model."""

    def __init__(self, model="yolov8s-world.pt") -> None:
        """
        Initialize the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.

        The task is always 'detect'. When the loaded model has no `names` attribute, the names from the
        bundled 'cfg/datasets/coco8.yaml' dataset config are assigned as defaults.

        Args:
            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
        """
        super().__init__(model=model, task="detect")

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """Map head to model, validator, and predictor classes (no trainer entry is provided)."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            }
        }

    def set_classes(self, classes):
        """
        Set the class names used by the model.

        The given classes are forwarded unchanged to the underlying model's `set_classes`. The names stored
        on the model (and on an already-created predictor) omit the first background entry (" ") if one is
        present. The caller's sequence is never modified.

        Args:
            classes (List[str]): A list of categories i.e ["person"].
        """
        self.model.set_classes(classes)

        # Strip the background entry from a private copy so the caller's list is left untouched
        names = list(classes)
        background = " "
        if background in names:
            names.remove(background)
        self.model.names = names

        # Refresh names on an existing predictor, otherwise the old names remain
        if self.predictor:
            self.predictor.model.names = names