# license_plate_recognition/grounded_sam_plate.py
# Silence non-critical warnings from dependencies
import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_error()
import cv2
import numpy as np
from PIL import Image
from glob import glob
from typing import Union
import termcolor
import os
import torch
import torchvision
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
from utils.recognize_characters import recognize_char
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "utils/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swinb_cogcoor.pth"
# Segment-Anything checkpoint
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "checkpoints/sam_vit_h_4b8939.pth"
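# Optional sanity check: fail fast with a clear error if a model checkpoint
# has not been downloaded yet.
for _ckpt_path in (GROUNDING_DINO_CHECKPOINT_PATH, SAM_CHECKPOINT_PATH):
    if not os.path.isfile(_ckpt_path):
        raise FileNotFoundError(f"Missing checkpoint: {_ckpt_path}")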
# Building GroundingDINO inference model
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=DEVICE)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('GroundingDINO', 'green')}, model path: {termcolor.colored(GROUNDING_DINO_CHECKPOINT_PATH, 'green')}")
# Building SAM Model and SAM Predictor
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(DEVICE)
sam_predictor = SamPredictor(sam)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('Segment-Anything', 'green')}, model path: {termcolor.colored(SAM_CHECKPOINT_PATH, 'green')}")
# Hyperparameters for GroundingDINO detection and post-processing
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
NMS_THRESHOLD = 0.8
# Target size (pixels) of the rectified plate crop
RECTIFIED_W, RECTIFIED_H = 600, 200
# Prompting SAM with detected boxes
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Prompt SAM with each detected box and keep the highest-scoring mask."""
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        # Keep the mask with the highest predicted quality score
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)
def recognize_plate(image_path: Union[np.ndarray, str], cut_ratio=0.15, save_image=False, print_probs=False):
    """Detect a license plate with GroundingDINO + SAM, rectify it, and read its characters.

    Accepts either an image path or an already-loaded BGR array.
    """
    if isinstance(image_path, str):
        image = cv2.imread(image_path)
        image_name = os.path.basename(image_path)
    else:
        image = image_path
        image_name = "image.jpg"  # fallback name for array inputs when save_image=True
    CLASSES = ['license plate', 'sky', 'person']
    # Detect objects; only 'license plate' boxes are used downstream
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )
    # Non-maximum suppression to drop overlapping boxes
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()
    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]
    # Convert detections to masks with SAM (which expects RGB input)
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )
    # Keep only 'license plate' masks (class_id == 0)
    result_masks = detections.mask[(detections.class_id == 0), :, :]
    result_masks = result_masks.astype(np.uint8)
    if len(result_masks) == 0:
        raise ValueError("No license plate detected in the image.")
    # Pick the smallest mask among the plate candidates
    min_area, min_mask = np.inf, np.zeros_like(result_masks[0])
    for mask in result_masks:
        area = np.sum(mask)
        if area < min_area:
            min_area = area
            min_mask = mask
    # Fit a minimum-area rotated rectangle around the smallest mask.
    # np.argwhere yields (row, col) = (y, x), so the coordinates are swapped back below.
    minrect = cv2.minAreaRect(np.argwhere(min_mask).astype(np.int32))
    box = cv2.boxPoints(minrect)
    box = box.astype(np.intp)  # np.int0 was removed in NumPy 2.0
    box[:, [0, 1]] = box[:, [1, 0]]  # (y, x) -> (x, y)
    # Draw the detected plate contour
    cv2.drawContours(image, [box], 0, (0, 0, 255), 2)
    if save_image:
        os.makedirs("contours", exist_ok=True)
        cv2.imwrite(f"contours/{image_name}", image)
    # Order the corners as top-left, bottom-left, bottom-right, top-right
    box = box[np.argsort(box[:, 0])]
    if box[0, 1] > box[1, 1]:
        box[[0, 1], :] = box[[1, 0], :]
    if box[2, 1] < box[3, 1]:
        box[[2, 3], :] = box[[3, 2], :]
    # Make the side lengths alternate short-long-short-long so the long edge
    # of the plate maps to the width of the rectified crop
    if np.linalg.norm(box[0] - box[1]) > np.linalg.norm(box[1] - box[2]):
        box[[1, 3], :] = box[[3, 1], :]
    # Cut out the license plate and rectify it to a fixed-size, axis-aligned crop
    dst_corners = np.array(
        [[0, 0], [0, RECTIFIED_H], [RECTIFIED_W, RECTIFIED_H], [RECTIFIED_W, 0]],
        dtype=np.float32
    )
    transform = cv2.getPerspectiveTransform(box.astype(np.float32), dst_corners)
    rectified_plate = cv2.warpPerspective(image, transform, (RECTIFIED_W, RECTIFIED_H))
    # Also keep a vertically flipped crop in case the plate is upside down
    rectified_plate_flip = cv2.flip(rectified_plate, 0)
    if save_image:
        os.makedirs("rectified_plate", exist_ok=True)
        cv2.imwrite(f"rectified_plate/{image_name}", rectified_plate)
        stem, ext = os.path.splitext(image_name)
        cv2.imwrite(f"rectified_plate/{stem}_flip{ext}", rectified_plate_flip)
    # Recognize characters on both the original and the flipped crop
    result = recognize_char(Image.fromarray(rectified_plate), cut_ratio=cut_ratio, print_probs=print_probs)
    result['rectified_plate'] = rectified_plate
    result_flip = recognize_char(Image.fromarray(rectified_plate_flip), cut_ratio=cut_ratio, print_probs=print_probs)
    result_flip['rectified_plate'] = rectified_plate_flip
    # Prefer the flipped reading when it yields a full 7-character plate with higher confidence
    if len(result_flip['plate']) == 7 and result_flip["confidence"] > result["confidence"]:
        result = result_flip
    result['detection'] = image
    return result
if __name__ == "__main__":
    image_dir = "images"
    image_list = glob(f"{image_dir}/*.jpg") + glob(f"{image_dir}/*.png") + glob(f"{image_dir}/*.jpeg")
    for image_path in image_list:
        result = recognize_plate(image_path, save_image=True, print_probs=True)
        print(f"Image path: {termcolor.colored(os.path.basename(image_path), 'green')} Recognized: {termcolor.colored(result['plate'], 'blue')}")