import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_error()

import os
from glob import glob
from typing import Union

import cv2
import numpy as np
import termcolor
import torch
import torchvision
from PIL import Image

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
from utils.recognize_characters import recognize_char

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "utils/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swinb_cogcoor.pth"

# Segment-Anything checkpoint
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "checkpoints/sam_vit_h_4b8939.pth"

# Build the GroundingDINO inference model
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=DEVICE,
)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, "
      f"model: {termcolor.colored('GroundingDINO', 'green')}, "
      f"model path: {termcolor.colored(GROUNDING_DINO_CHECKPOINT_PATH, 'green')}")

# Build the SAM model and predictor
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(DEVICE)
sam_predictor = SamPredictor(sam)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, "
      f"model: {termcolor.colored('Segment-Anything', 'green')}, "
      f"model path: {termcolor.colored(SAM_CHECKPOINT_PATH, 'green')}")

# Detection hyper-parameters for GroundingDINO and the rectified-plate size
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
NMS_THRESHOLD = 0.8
RECTIFIED_W, RECTIFIED_H = 600, 200


# Prompt SAM with the detected boxes and keep the highest-scoring mask per box
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)


def recognize_plate(image_path: Union[np.ndarray, str], cut_ratio=0.15, save_image=False, print_probs=False):
    # Accept either a file path or an already-loaded BGR image
    if isinstance(image_path, str):
        image = cv2.imread(image_path)
    else:
        image = image_path
    # 'sky' and 'person' act as distractor prompts alongside the target class
    CLASSES = ['license plate', 'sky', 'person']

    # Detect objects with GroundingDINO
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    # NMS post-processing
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()
    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    # Convert the surviving boxes to masks (SAM expects RGB input)
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )

    # Keep only the masks for class_id == 0 ('license plate')
    result_masks = detections.mask[(detections.class_id == 0), :, :]
    result_masks = result_masks.astype(np.uint8)

    # Pick the smallest-area mask as the plate candidate
    min_area, min_mask = np.inf, np.zeros_like(result_masks[0])
    for mask in result_masks:
        area = np.sum(mask)
        if area < min_area:
            min_area = area
            min_mask = mask

    # Fit a minimum-area rotated rectangle around the mask pixels.
    # np.argwhere yields int64 (row, col) points, so cast to int32 for OpenCV
    # and swap the resulting box columns back to (x, y) afterwards.
    minrect = cv2.minAreaRect(np.argwhere(min_mask).astype(np.int32))
    box = cv2.boxPoints(minrect)
    box = box.astype(np.intp)
    box[:, [0, 1]] = box[:, [1, 0]]

    # Draw the detected box on the original image
    cv2.drawContours(image, [box], 0, (0, 0, 255), 2)
    if save_image:
        os.makedirs("contours", exist_ok=True)
        cv2.imwrite(f"contours/{os.path.basename(image_path)}", image)
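    # The corner order from cv2.boxPoints() (further scrambled by the row/col
    # swap above) says nothing about which corner is which on the plate, so the
    # steps below re-order the points to match the destination quad used for
    # perspective rectification.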
    # Order the box points as top-left, bottom-left, bottom-right, top-right:
    # sort by x, then fix the vertical order within the left and right pairs
    box = box[np.argsort(box[:, 0])]
    if box[0, 1] > box[1, 1]:
        box[[0, 1], :] = box[[1, 0], :]
    if box[2, 1] < box[3, 1]:
        box[[2, 3], :] = box[[3, 2], :]

    # Re-order by side length (short-long-short-long) so the short edges map to
    # the plate's height and the long edges to its width
    if np.linalg.norm(box[0] - box[1]) > np.linalg.norm(box[1] - box[2]):
        box[[1, 3], :] = box[[3, 1], :]

    # Cut out the license plate and rectify it to RECTIFIED_W x RECTIFIED_H
    rectified_plate = cv2.warpPerspective(
        image,
        cv2.getPerspectiveTransform(
            box.astype(np.float32),
            np.array([[0, 0], [0, RECTIFIED_H], [RECTIFIED_W, RECTIFIED_H], [RECTIFIED_W, 0]],
                     dtype=np.float32)
        ),
        (RECTIFIED_W, RECTIFIED_H)
    )
    # The corner re-ordering can leave the crop flipped, so try both orientations
    rectified_plate_flip = cv2.flip(rectified_plate, 0)
    if save_image:
        os.makedirs("rectified_plate", exist_ok=True)
        basename = os.path.basename(image_path)
        stem, ext = os.path.splitext(basename)
        cv2.imwrite(f"rectified_plate/{basename}", rectified_plate)
        cv2.imwrite(f"rectified_plate/{stem}_flip{ext}", rectified_plate_flip)

    # Recognize characters on both orientations and keep the more confident
    # result (note: rectified_plate is BGR here; if recognize_char expects RGB,
    # convert with cv2.cvtColor first)
    result = recognize_char(Image.fromarray(rectified_plate), cut_ratio=cut_ratio, print_probs=print_probs)
    result['rectified_plate'] = rectified_plate
    result_flip = recognize_char(Image.fromarray(rectified_plate_flip), cut_ratio=cut_ratio, print_probs=print_probs)
    result_flip['rectified_plate'] = rectified_plate_flip
    if len(result_flip['plate']) == 7 and result_flip['confidence'] > result['confidence']:
        result = result_flip
    result['detection'] = image
    return result


if __name__ == "__main__":
    image_dir = "images"
    image_list = glob(f"{image_dir}/*.jpg") + glob(f"{image_dir}/*.png") + glob(f"{image_dir}/*.jpeg")
    for image_path in image_list:
        result = recognize_plate(image_path, save_image=True, print_probs=True)
        # Print only the recognized string; result also carries the ndarrays
        # 'rectified_plate' and 'detection', which would flood the terminal
        print(f"Image path: {termcolor.colored(os.path.basename(image_path), 'green')} "
              f"Recognized: {termcolor.colored(result['plate'], 'blue')}")
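

# Optional usage sketch (an assumption, not part of the original pipeline):
# because recognize_plate() also accepts an in-memory BGR np.ndarray, it can be
# driven from a video stream. The helper name, the "every_n" sampling parameter,
# and the video path are hypothetical.
def recognize_plate_from_video(video_path: str, every_n: int = 30):
    cap = cv2.VideoCapture(video_path)
    results, frame_idx = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if frame_idx % every_n == 0:
            # save_image must stay False for ndarray input: the saving branches
            # call os.path.basename() on the argument, which needs a str path
            try:
                results.append(recognize_plate(frame, save_image=False))
            except (IndexError, cv2.error):
                pass  # frames with no detected plate fail downstream; skip them
        frame_idx += 1
    cap.release()
    return results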