# license_plate_recognition/utils/recognize_characters.py
# Duplicated from zxbsmk/license_plate_recognition (commit d94f42d).
import warnings
warnings.filterwarnings("ignore")
import torch
from PIL import Image
from torchvision import transforms as T
from glob import glob
import os
import re
import termcolor
from utils.iqa_recognize import recognize_chinese_char
# --- Model and preprocessing setup (executed once, when the module is imported) ---

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained PARSeq text recognizer, loaded through torch.hub from the
# baudm/parseq repository. `device` is forwarded to the hub entrypoint as a
# keyword argument. NOTE(review): confirm the entrypoint actually places the
# weights on that device instead of silently ignoring the keyword.
parseq = torch.hub.load(
    "baudm/parseq",
    "parseq",
    pretrained=True,
    device=device,
    verbose=False,
)
parseq.eval()  # switch to evaluation mode (eval() mutates and returns the same module)

# Announce the chosen device and model, highlighted in green.
print(f"Using device: {termcolor.colored(device, 'green')}, model: {termcolor.colored('parseq', 'green')}")

# Preprocessing for PARSeq:
#   1. resize to the model's configured input size (bicubic),
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalize with mean = std = 0.5, i.e. map [0, 1] onto [-1, 1].
img_transform = T.Compose(
    [
        T.Resize(parseq.hparams.img_size, T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ]
)

# Punctuation and whitespace characters removed from the decoded plate text.
filtered_chars = list("-_,.:!() ")
def recognize_char(img: Image.Image, image_path: "str | None" = None, cut_ratio=0.15, save_image=False, print_probs=False):
    """Recognize the text on a rectified license-plate image.

    The image is split vertically at ``cut_ratio`` of its width. The left
    strip goes to ``recognize_chinese_char``. The remaining right part is read
    by the PARSeq model loaded at module import time.

    Args:
        img: Plate image. Ignored when ``image_path`` is given. Non-RGB images
            are converted to RGB, which matches what the ``image_path`` branch
            does.
        image_path: Optional path to load the plate image from.
        cut_ratio: Fraction of the image width, measured from the left edge,
            that is cropped off and sent to ``recognize_chinese_char``.
        save_image: When True and ``image_path`` is given, the left strip is
            written to ``cut_plate/<basename of image_path>``.
        print_probs: Forwarded unchanged to ``recognize_chinese_char``.

    Returns:
        dict with keys:
            "plate": result of ``recognize_chinese_char`` concatenated with the
                PARSeq label, with every character in ``filtered_chars``
                removed from the label;
            "confidence": float mean of the confidence values returned by
                ``parseq.tokenizer.decode`` for this image;
            "chinese": the cropped left strip as a PIL image.

    Raises:
        ValueError: if neither ``img`` nor ``image_path`` is provided.
    """
    if image_path is not None:
        # The context manager closes the file handle immediately.
        # convert() returns a fully loaded copy, so the image stays usable.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
    elif img is None:
        raise ValueError("recognize_char requires either `img` or `image_path`")
    elif img.mode != "RGB":
        # The model is fed 3-channel tensors. RGBA / L / P inputs would
        # otherwise reach it with the wrong number of channels.
        img = img.convert("RGB")

    width, height = img.size
    split_x = width * cut_ratio  # PIL rounds fractional crop coordinates

    left_part = img.crop((0, 0, split_x, height))
    if image_path is not None and save_image:
        os.makedirs("cut_plate", exist_ok=True)
        left_part.save(f"cut_plate/{os.path.basename(image_path)}")
    left_char = recognize_chinese_char(left_part, print_probs=print_probs)

    right_part = img.crop((split_x, 0, width, height))
    # Place the batch on whatever device the model's weights actually live on.
    # This avoids a CPU-tensor / GPU-model mismatch (and vice versa).
    model_device = next(parseq.parameters()).device
    batch = img_transform(right_part).unsqueeze(0).to(model_device)

    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        logits = parseq(batch)
        pred = logits.softmax(-1)
        label, confidence = parseq.tokenizer.decode(pred)

    # str.translate deletes each filtered character literally. This avoids the
    # regex character-class pitfalls of re.sub, e.g. '-' forming a range when
    # it is not the first character.
    text = label[0].translate(str.maketrans("", "", "".join(filtered_chars)))
    return {
        "plate": left_char + text,
        "confidence": float(confidence[0].mean()),
        "chinese": left_part,
    }
if __name__ == "__main__":
    # Batch-recognize every rectified plate image found in ./rectified_plate.
    # glob() yields files in arbitrary filesystem order, so sort the list to
    # make runs reproducible.
    img_paths = sorted(
        path
        for pattern in ("*.jpg", "*.png", "*.jpeg")
        for path in glob(f"rectified_plate/{pattern}")
    )
    for img_path in img_paths:
        # save_image=True also writes the cropped left (Chinese-character)
        # strip to ./cut_plate/ for inspection.
        result = recognize_char(None, img_path, save_image=True)
        print(f"Recognized: {termcolor.colored(result, 'blue')}")