Spaces:
Sleeping
Sleeping
File size: 10,607 Bytes
53ad959 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
# Ultralytics YOLO 🚀, AGPL-3.0 license
from pathlib import Path
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
class PoseValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on a pose model.

    On top of the box metrics handled by the parent class, this validator tracks a second
    set of true-positive flags (``tp_p``) computed from keypoint similarity.

    Example:
        ```python
        from ultralytics.models.yolo.pose import PoseValidator

        args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
        validator = PoseValidator(args=args)
        validator()
        ```
    """

    def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
        """
        Initialize a 'PoseValidator' object with custom parameters and assigned attributes.

        All arguments are forwarded unchanged to ``DetectionValidator.__init__``. Afterwards the
        task is forced to "pose", pose metrics are attached, and a warning is logged when the
        configured device is the string "mps" (case-insensitive).
        """
        super().__init__(dataloader, save_dir, pbar, args, _callbacks)
        # Both are filled in by init_metrics() once the dataset description is available.
        self.sigma = None
        self.kpt_shape = None
        self.args.task = "pose"
        self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def preprocess(self, batch):
        """Run the parent preprocessing, then move batch['keypoints'] to the device as float."""
        batch = super().preprocess(batch)
        batch["keypoints"] = batch["keypoints"].to(self.device).float()
        return batch

    def get_desc(self):
        """Return the formatted header row naming the box and pose metric columns."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Pose(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def postprocess(self, preds):
        """Apply non-maximum suppression to raw predictions using the configured thresholds."""
        return ops.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            labels=self.lb,
            multi_label=True,
            agnostic=self.args.single_cls,
            max_det=self.args.max_det,
            nc=self.nc,
        )

    def init_metrics(self, model):
        """
        Initiate pose estimation metrics for YOLO model.

        Reads ``kpt_shape`` from the dataset description, selects the per-keypoint sigma
        values used for keypoint similarity, and resets the per-image statistics buffers.
        """
        super().init_metrics(model)
        self.kpt_shape = self.data["kpt_shape"]
        # Compare as a list so a tuple-valued kpt_shape such as (17, 3) matches as well.
        is_pose = list(self.kpt_shape) == [17, 3]
        nkpt = self.kpt_shape[0]
        # The predefined OKS_SIGMA table is only used for the 17x3 layout; any other
        # layout falls back to uniform weights of 1 / nkpt.
        self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
        self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])

    def _prepare_batch(self, si, batch):
        """
        Build the per-image ground-truth dict for image index ``si``, including keypoints.

        The keypoints belonging to image ``si`` are selected via ``batch['batch_idx']``,
        copied, multiplied by the image width (x) and height (y) taken from
        ``pbatch['imgsz']``, and then passed through ``ops.scale_coords`` towards
        ``pbatch['ori_shape']``. The result is stored under ``pbatch['kpts']``.
        """
        pbatch = super()._prepare_batch(si, batch)
        kpts = batch["keypoints"][batch["batch_idx"] == si]
        h, w = pbatch["imgsz"]
        kpts = kpts.clone()  # avoid modifying the tensor held by the batch
        kpts[..., 0] *= w
        kpts[..., 1] *= h
        kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
        pbatch["kpts"] = kpts
        return pbatch

    def _prepare_pred(self, pred, pbatch):
        """
        Prepare predictions for one image and extract their keypoints.

        Returns:
            tuple: ``(predn, pred_kpts)`` where ``pred_kpts`` is a view of columns 6 onwards
            of ``predn`` reshaped to ``(len(predn), nk, -1)``, with ``nk`` taken from the
            ground-truth keypoints in ``pbatch``.
        """
        predn = super()._prepare_pred(pred, pbatch)
        nk = pbatch["kpts"].shape[1]
        pred_kpts = predn[:, 6:].view(len(predn), nk, -1)
        # The return value is discarded on purpose; presumably scale_coords rescales
        # pred_kpts in place - confirm against ops.scale_coords.
        ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"])
        return predn, pred_kpts

    def update_metrics(self, preds, batch):
        """
        Accumulate per-image statistics for box and pose metrics.

        For every image this records confidences, predicted classes, target classes and the
        true-positive flags for boxes (``tp``) and keypoints (``tp_p``) in ``self.stats``.
        It optionally updates the confusion matrix and the JSON prediction list.
        """
        for si, pred in enumerate(preds):
            self.seen += 1
            npr = len(pred)
            stat = dict(
                conf=torch.zeros(0, device=self.device),
                pred_cls=torch.zeros(0, device=self.device),
                tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
                tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device),
            )
            pbatch = self._prepare_batch(si, batch)
            cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
            nl = len(cls)
            stat["target_cls"] = cls

            if npr == 0:
                # No predictions: only record something when there were labels to miss.
                if nl:
                    for k in self.stats:
                        self.stats[k].append(stat[k])
                    if self.args.plots:
                        self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls)
                continue

            # Predictions
            if self.args.single_cls:
                pred[:, 5] = 0
            predn, pred_kpts = self._prepare_pred(pred, pbatch)
            stat["conf"] = predn[:, 4]
            stat["pred_cls"] = predn[:, 5]

            # Evaluate
            if nl:
                stat["tp"] = self._process_batch(predn, bbox, cls)
                stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"])
                if self.args.plots:
                    self.confusion_matrix.process_batch(predn, bbox, cls)

            for k in self.stats:
                self.stats[k].append(stat[k])

            # Save
            if self.args.save_json:
                self.pred_to_json(predn, batch["im_file"][si])
            # NOTE: writing per-image txt results (save_txt) is not implemented here.

    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
        """
        Return correct prediction matrix.

        When both ``pred_kpts`` and ``gt_kpts`` are given, matching is based on keypoint
        similarity (``kpt_iou``); otherwise it is based on box IoU (``box_iou``).

        Args:
            detections (torch.Tensor): Detections, one row each. Columns 0-3 are used as the
                box (x1, y1, x2, y2) and column 5 as the class; column 4 is the confidence.
            gt_bboxes (torch.Tensor): Ground-truth boxes, one row each, in x1, y1, x2, y2 format.
            gt_cls (torch.Tensor): Ground-truth class for each box in ``gt_bboxes``.
            pred_kpts (torch.Tensor, optional): Predicted keypoints, one entry per detection.
            gt_kpts (torch.Tensor, optional): Ground-truth keypoints, one entry per ground-truth box.

        Returns:
            torch.Tensor: Result of ``self.match_predictions``, i.e. the correct-prediction
            matrix with one row per detection (presumably one column per IoU threshold).
        """
        if pred_kpts is not None and gt_kpts is not None:
            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
        else:  # boxes
            iou = box_iou(gt_bboxes, detections[:, :4])
        return self.match_predictions(detections[:, 5], gt_cls, iou)

    def plot_val_samples(self, batch, ni):
        """Plot ground-truth boxes and keypoints for batch ``ni`` and save them as an image."""
        plot_images(
            batch["img"],
            batch["batch_idx"],
            batch["cls"].squeeze(-1),
            batch["bboxes"],
            kpts=batch["keypoints"],
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch, preds, ni):
        """Plot predicted boxes and keypoints for batch ``ni`` and save them as an image."""
        # Columns 6 onwards of each prediction hold the flattened keypoints.
        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
        plot_images(
            batch["img"],
            *output_to_target(preds, max_det=self.args.max_det),
            kpts=pred_kpts,
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred

    def pred_to_json(self, predn, filename):
        """
        Append the predictions of one image to ``self.jdict`` in COCO JSON record format.

        The image id is the file stem, converted to ``int`` when it is numeric.
        """
        stem = Path(filename).stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn[:, :4])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for p, b in zip(predn.tolist(), box.tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "category_id": self.class_map[int(p[5])],
                    "bbox": [round(x, 3) for x in b],
                    "keypoints": p[6:],
                    "score": round(p[4], 5),
                }
            )

    def eval_json(self, stats):
        """
        Evaluate saved predictions with pycocotools and merge the results into ``stats``.

        Runs only when JSON saving is enabled, the dataset is COCO and predictions were
        collected. Both a "bbox" and a "keypoints" evaluation are performed. Any failure
        (missing package, missing files, evaluation error) is logged as a warning and
        ``stats`` is returned without the pycocotools values.

        Args:
            stats (dict): Metric dictionary to update.

        Returns:
            dict: The same ``stats`` object, possibly updated.
        """
        if self.args.save_json and self.is_coco and len(self.jdict):
            anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json"  # annotations
            pred_json = self.save_dir / "predictions.json"  # predictions
            LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...")
            try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
                check_requirements("pycocotools>=2.0.6")
                from pycocotools.coco import COCO  # noqa
                from pycocotools.cocoeval import COCOeval  # noqa

                for x in anno_json, pred_json:
                    # Explicit raise instead of assert so the check survives `python -O`.
                    if not x.is_file():
                        raise FileNotFoundError(f"{x} file not found")
                anno = COCO(str(anno_json))  # init annotations api
                pred = anno.loadRes(str(pred_json))  # init predictions api (must pass string, not Path)
                evaluators = [COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]
                for i, evaluator in enumerate(evaluators):
                    # is_coco is already guaranteed by the enclosing condition.
                    evaluator.params.imgIds = [
                        int(Path(x).stem) for x in self.dataloader.dataset.im_files
                    ]  # im to eval
                    evaluator.evaluate()
                    evaluator.accumulate()
                    evaluator.summarize()
                    idx = i * 4 + 2
                    stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = evaluator.stats[
                        :2
                    ]  # update mAP50-95 and mAP50
            except Exception as e:
                # Deliberately best-effort: validation results must not be lost to a COCO eval failure.
                LOGGER.warning(f"pycocotools unable to run: {e}")
        return stats
|