Spaces:
Runtime error
Runtime error
import warnings | |
warnings.filterwarnings("ignore") | |
import torch | |
from PIL import Image | |
from torchvision import transforms as T | |
from glob import glob | |
import os | |
import re | |
import termcolor | |
from utils.iqa_recognize import recognize_chinese_char | |
# Select the compute device once, at import time.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained PARSeq recognizer obtained through torch.hub (may need network
# access on first load), switched to inference mode with .eval().
# NOTE(review): nothing here calls .to(device); whether the `device=` kwarg is
# honoured depends on the baudm/parseq hub entry point -- confirm.
parseq = torch.hub.load(
    "baudm/parseq",
    "parseq",
    pretrained=True,
    device=device,
    verbose=False,
).eval()

# Announce the selected device and model name, both highlighted in green.
print(
    "Using device: "
    + termcolor.colored(device, "green")
    + ", model: "
    + termcolor.colored("parseq", "green")
)

# PARSeq preprocessing: bicubic resize to the model's configured input size,
# PIL -> tensor, then normalize with mean = std = 0.5.
img_transform = T.Compose(
    [
        T.Resize(parseq.hparams.img_size, T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(0.5, 0.5),
    ]
)

# Characters stripped from PARSeq's decoded label by recognize_char.
filtered_chars = ["-", "_", ",", ".", ":", "!", "(", ")", " "]
def recognize_char(img: Image.Image, image_path: str = None, cut_ratio=0.15, save_image=False, print_probs=False):
    """Read a plate string by splitting the image into a left crop and the rest.

    The leftmost ``cut_ratio`` of the image width is handed to
    ``recognize_chinese_char``; the remaining part is read by the PARSeq model.

    Args:
        img: Plate image to recognize. Ignored when ``image_path`` is given.
            Non-RGB images are converted to RGB first.
        image_path: Optional path of an image file to load instead of ``img``.
        cut_ratio: Fraction of the image width, measured from the left edge,
            that forms the crop given to ``recognize_chinese_char``.
        save_image: If True and ``image_path`` is given, the left crop is saved
            as ``cut_plate/<basename of image_path>`` (directory created if needed).
        print_probs: Forwarded unchanged to ``recognize_chinese_char``.

    Returns:
        A dict with:
            "plate": the result of ``recognize_chinese_char`` on the left crop,
                concatenated with PARSeq's decoded label after every character
                in ``filtered_chars`` has been removed.
            "confidence": mean of the confidence tensor that
                ``parseq.tokenizer.decode`` returns for the label, as a float.
            "chinese": the left-crop PIL image.

    Raises:
        ValueError: If both ``img`` and ``image_path`` are None.
    """
    if image_path is not None:
        # convert() copies the pixels, so the source file can be closed right
        # away instead of staying open until the Image object is collected.
        with Image.open(image_path) as src:
            img = src.convert("RGB")
    elif img is None:
        raise ValueError("recognize_char requires either `img` or `image_path`")
    elif img.mode != "RGB":
        # Mirror the path branch: grayscale/RGBA/palette inputs would otherwise
        # reach the 3-channel model with the wrong channel count.
        img = img.convert("RGB")

    width, height = img.size
    # Integer pixel column where the left crop ends.
    split_x = round(width * cut_ratio)

    left_part = img.crop((0, 0, split_x, height))
    if image_path is not None and save_image:
        os.makedirs("cut_plate", exist_ok=True)
        left_part.save(os.path.join("cut_plate", os.path.basename(image_path)))
    left_char = recognize_chinese_char(left_part, print_probs=print_probs)

    right_part = img.crop((split_x, 0, width, height))
    batch = img_transform(right_part).unsqueeze(0)  # add a batch dimension

    # Put the input on whatever device holds the model weights, so this works
    # whether or not the model was placed on the GPU.
    model_device = next(parseq.parameters()).device
    with torch.no_grad():  # pure inference: no autograd graph needed
        logits = parseq(batch.to(model_device))
        probs = logits.softmax(-1)
        labels, confidences = parseq.tokenizer.decode(probs)

    # Delete each filtered character; unlike a hand-built regex character
    # class, this needs no escaping of special characters like "]" or "^".
    strip_table = str.maketrans("", "", "".join(filtered_chars))
    label = labels[0].translate(strip_table)

    return {
        "plate": left_char + label,
        "confidence": float(confidences[0].mean()),
        "chinese": left_part,
    }
if __name__ == "__main__":
    # Batch mode: recognize every plate image found in ./rectified_plate.
    image_exts = (".jpg", ".png", ".jpeg")
    # Match extensions case-insensitively (so "plate.JPG" is not skipped) and
    # sort so the processing/print order is identical on every run.
    img_paths = sorted(
        path
        for path in glob(os.path.join("rectified_plate", "*"))
        if os.path.isfile(path) and path.lower().endswith(image_exts)
    )
    for img_path in img_paths:
        try:
            result = recognize_char(None, img_path, save_image=True)
        except OSError as err:
            # e.g. a file PIL cannot open/decode, or a failed save of the crop:
            # report it and carry on with the rest of the batch.
            print(f"Skipping {img_path}: {err}")
            continue
        print(f"Recognized: {termcolor.colored(result, 'blue')}")