File size: 2,565 Bytes
d94f42d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import clip
from glob import glob
from PIL import Image
import termcolor

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load hands back both the network and the image transform it expects.
model, preprocess = clip.load("ViT-B/32", device=device)
model.eval()  # this script only runs inference

# Announce the selected device and model, highlighted via termcolor.
print(
    f"Using device: {termcolor.colored(device, 'green')}, "
    f"model: {termcolor.colored('ViT-B/32', 'green')}"
)


template_dir = "character_template"
# Template file name (inside template_dir) -> the character reported when that
# template is the best match for an input image.
_template_chars = {
    "e.png": "鄂", "gui.png": "桂",
    "hei.png": "黑", "ji.png": "冀",
    "gui1.png": "贵", "jing.png": "京",
    "lu.png": "鲁", "min.png": "闽",
    "su.png": "苏", "wan.png": "皖",
    "yu.png": "豫", "yue.png": "粤",
    "xin.png": "新",
}
# Keys are "<template_dir>/<file name>" joined with a forward slash, so the
# directory is spelled out in exactly one place.
char_info = {
    f"{template_dir}/{file_name}": char
    for file_name, char in _template_chars.items()
}
# Same order as char_info, so index i of char_list matches row i of the batch.
char_list = list(char_info.values())


def _load_template(template_path):
    """Open one template image and return it as a preprocessed 1-image batch on `device`."""
    # The context manager closes the file once preprocess has read the pixels.
    with Image.open(template_path) as template_image:
        return preprocess(template_image).unsqueeze(0).to(device)


# Load every template first, then concatenate once instead of growing the
# batch with a torch.cat per template.
_template_tensors = [_load_template(template_path) for template_path in char_info]
# Batch of all templates stacked along dim 0 (None when there are no templates).
character_tensor_list = torch.cat(_template_tensors, dim=0) if _template_tensors else None
print(f"Support Chinese characters: {termcolor.colored(char_list, 'blue')}")


# Cached pair (template batch, its normalised CLIP features). The templates do
# not change between calls, so they are encoded once rather than on every call.
_char_feature_cache = None


def _get_char_features():
    """Return the L2-normalised CLIP features of the template batch, encoding it only once."""
    global _char_feature_cache
    # Re-encode only on first use or if the template batch object was replaced.
    if _char_feature_cache is None or _char_feature_cache[0] is not character_tensor_list:
        with torch.no_grad():
            features = model.encode_image(character_tensor_list)
            features = features / features.norm(dim=1, keepdim=True)
        _char_feature_cache = (character_tensor_list, features)
    return _char_feature_cache[1]


def recognize_chinese_char(image: Image.Image, image_path: str=None, print_probs=False):
    """Return the character whose template image is most similar to the input image.

    The input and every template are embedded with the CLIP image encoder;
    the scaled cosine similarities are turned into probabilities with a
    softmax, and the character with the highest probability is returned.

    Args:
        image: Image to classify. Ignored when `image_path` is given.
        image_path: Optional path of an image file; when not None the file is
            opened, converted to RGB and used instead of `image`.
        print_probs: When true, print the probability of every supported
            character.

    Returns:
        The entry of `char_list` with the highest probability.

    Raises:
        ValueError: If both `image` and `image_path` are None.
    """
    if image_path is not None:
        # convert() returns an independent copy, so the file can be closed right away.
        with Image.open(image_path) as opened:
            image = opened.convert('RGB')
    elif image is None:
        raise ValueError("Either `image` or `image_path` must be provided.")

    # Batch containing just this one image.
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_tensor)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        char_features = _get_char_features()

        # Both sides are unit length, so the product is a scaled cosine similarity.
        logit_scale = model.logit_scale.exp()
        logits_per_image = (logit_scale * image_features) @ char_features.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # Only one image was encoded, so the first row holds all the probabilities.
    char_probs = probs[0]
    if print_probs:
        prob_dict = dict(zip(char_list, char_probs))
        print(f"Label probs: {termcolor.colored(prob_dict, 'red')}")
    return char_list[int(char_probs.argmax())]



if __name__ == "__main__":
    # Collect the left-hand plate crops, all .jpg matches first and then all
    # .png matches, and print the recognised character next to each path.
    image_list = [
        matched_path
        for pattern in ("cut_plate/left_*.jpg", "cut_plate/left_*.png")
        for matched_path in glob(pattern)
    ]
    for image_path in image_list:
        print(image_path, recognize_chinese_char(None, image_path))