Spaces:
Sleeping
Sleeping
File size: 6,306 Bytes
53ad959 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Ultralytics YOLO 🚀, AGPL-3.0 license
import math
import random
from copy import copy
import numpy as np
import torch.nn as nn
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
"""
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
"""Construct and return dataloader."""
assert mode in ["train", "val"]
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == "train"
if getattr(dataset, "rect", False) and shuffle:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
workers = self.args.workers if mode == "train" else self.args.workers * 2
return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
if self.args.multi_scale:
imgs = batch["img"]
sz = (
random.randrange(self.args.imgsz * 0.5, self.args.imgsz * 1.5 + self.stride)
// self.stride
* self.stride
) # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
def set_model_attributes(self):
"""Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)."""
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data["nc"] # attach number of classes to model
self.model.names = self.data["names"] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.detect.DetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items=None, prefix="train"):
"""
Returns a loss dict with labelled training loss items tensor.
Not needed for classification but necessary for segmentation & detection
"""
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
else:
return keys
def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(
images=batch["img"],
batch_idx=batch["batch_idx"],
cls=batch["cls"].squeeze(-1),
bboxes=batch["bboxes"],
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|