# License-plate detection and recognition pipeline: GroundingDINO (detection)
# + Segment-Anything (mask refinement) + a character recognizer (OCR).
import warnings

warnings.filterwarnings("ignore")  # silence library warnings before heavy imports

from transformers import logging

logging.set_verbosity_error()  # keep transformers quiet except for errors

import os
from glob import glob
from typing import Union

import cv2
import numpy as np
import termcolor
import torch
import torchvision
from PIL import Image

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
from utils.recognize_characters import recognize_char

# Prefer GPU when available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "utils/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swinb_cogcoor.pth"
# Segment-Anything checkpoint
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "checkpoints/sam_vit_h_4b8939.pth"

# Build the GroundingDINO inference model.
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=DEVICE,
)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('GroundingDINO', 'green')}, model path: {termcolor.colored(GROUNDING_DINO_CHECKPOINT_PATH, 'green')}")

# Build SAM and its predictor.
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(DEVICE)
sam_predictor = SamPredictor(sam)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('Segment-Anything', 'green')}, model path: {termcolor.colored(SAM_CHECKPOINT_PATH, 'green')}")

# GroundingDINO detection hyper-parameters.
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
NMS_THRESHOLD = 0.8
# Output size (pixels) of the rectified plate crop.
RECTIFIED_W, RECTIFIED_H = 600, 200
# Prompting SAM with detected boxes | |
def segment(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Run SAM on each detection box and return the best mask per box.

    Args:
        sam_predictor: An initialized ``SamPredictor`` (any object exposing
            ``set_image`` and ``predict`` works).
        image: RGB image as an (H, W, 3) array.
        xyxy: (N, 4) array of detection boxes in xyxy pixel coordinates.

    Returns:
        (N, H, W) array with one mask per input box — for each box, the
        highest-scoring of SAM's multimask proposals.
    """
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        # Ask for multiple mask hypotheses, keep the most confident one.
        masks, scores, _ = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        result_masks.append(masks[np.argmax(scores)])
    return np.array(result_masks)
def recognize_plate(image_path: Union[np.ndarray, str], cut_ratio=0.15, save_image=False, print_probs=False):
    """Detect, rectify, and read a license plate from an image.

    Args:
        image_path: Path to an image file, or an already-loaded BGR image array.
        cut_ratio: Fraction forwarded to the character recognizer.
        save_image: If True, dump contour/rectified debug images to disk.
        print_probs: If True, the recognizer prints per-character probabilities.

    Returns:
        The dict from ``recognize_char`` augmented with ``rectified_plate``
        (the warped plate crop) and ``detection`` (the input image with the
        plate contour drawn on it).

    Raises:
        ValueError: If no license plate is detected in the image.
    """
    if isinstance(image_path, str):
        image = cv2.imread(image_path)
        image_name = os.path.basename(image_path)
    else:
        image = image_path
        # Fallback name so save_image works for ndarray inputs
        # (os.path.basename on an array raised TypeError before).
        image_name = "array_input.jpg"
    CLASSES = ['license plate', 'sky', 'person']
    # Detect candidate objects with GroundingDINO.
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD  # was BOX_THRESHOLD; TEXT_THRESHOLD was defined but unused
    )
    # NMS post process.
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()
    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]
    # Convert surviving boxes to segmentation masks (SAM expects RGB input).
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )
    # Keep only 'license plate' (class_id == 0) masks.
    result_masks = detections.mask[(detections.class_id == 0), :, :].astype(np.uint8)
    if result_masks.shape[0] == 0:
        # Previously fell through to an IndexError; fail with a clear message.
        raise ValueError("No license plate detected in the image")
    # Pick the smallest-area mask: the tightest plate hypothesis.
    areas = result_masks.reshape(result_masks.shape[0], -1).sum(axis=1)
    min_mask = result_masks[int(np.argmin(areas))]
    # Minimum-area rotated rectangle around the mask pixels.
    # np.argwhere yields (row, col) == (y, x), so swap columns back afterwards.
    minrect = cv2.minAreaRect(np.argwhere(min_mask))
    box = cv2.boxPoints(minrect).astype(np.intp)  # np.int0 was removed in NumPy 2.0
    box[:, [0, 1]] = box[:, [1, 0]]
    # Draw the detected plate contour on the input image.
    cv2.drawContours(image, [box], 0, (0, 0, 255), 2)
    if save_image:
        os.makedirs("contours", exist_ok=True)
        cv2.imwrite(f"contours/{image_name}", image)
    # Order corners: left pair then right pair (by x), top-left first.
    box = box[np.argsort(box[:, 0])]
    if box[0, 1] > box[1, 1]:
        box[[0, 1], :] = box[[1, 0], :]
    if box[2, 1] < box[3, 1]:
        box[[2, 3], :] = box[[3, 2], :]
    # Ensure short-long-short-long side ordering so the warp keeps aspect.
    if np.linalg.norm(box[0] - box[1]) > np.linalg.norm(box[1] - box[2]):
        box[[1, 3], :] = box[[3, 1], :]
    # Cut out the license plate and rectify it to a fixed-size crop.
    target = np.array([[0, 0], [0, RECTIFIED_H], [RECTIFIED_W, RECTIFIED_H], [RECTIFIED_W, 0]], dtype=np.float32)
    transform = cv2.getPerspectiveTransform(box.astype(np.float32), target)
    rectified_plate = cv2.warpPerspective(image, transform, (RECTIFIED_W, RECTIFIED_H))
    # Corner ordering can leave the plate upside-down; keep the flip as well.
    rectified_plate_flip = cv2.flip(rectified_plate, 0)
    if save_image:
        os.makedirs("rectified_plate", exist_ok=True)
        cv2.imwrite(f"rectified_plate/{image_name}", rectified_plate)
        cv2.imwrite(f"rectified_plate/{image_name}_flip.jpg", rectified_plate_flip)
    # OCR both orientations; keep the flipped read only when it is a full
    # 7-character plate AND more confident than the upright read.
    result = recognize_char(Image.fromarray(rectified_plate), cut_ratio=cut_ratio, print_probs=print_probs)
    result['rectified_plate'] = rectified_plate
    result_flip = recognize_char(Image.fromarray(rectified_plate_flip), cut_ratio=cut_ratio, print_probs=print_probs)
    result_flip['rectified_plate'] = rectified_plate_flip
    if len(result_flip['plate']) == 7 and result_flip["confidence"] > result["confidence"]:
        result = result_flip
    result['detection'] = image
    return result
if __name__ == "__main__":
    # Run the pipeline over every image in the images/ directory.
    image_dir = "images"
    image_list = []
    for ext in ("jpg", "png", "jpeg"):
        image_list.extend(glob(f"{image_dir}/*.{ext}"))
    for image_path in image_list:
        result = recognize_plate(image_path, save_image=True, print_probs=True)
        colored_name = termcolor.colored(os.path.basename(image_path), 'green')
        colored_result = termcolor.colored(result, 'blue')
        print(f"Image path: {colored_name} Recognized: {colored_result}")