Spaces:
Runtime error
Runtime error
import torch | |
import clip | |
from glob import glob | |
from PIL import Image | |
import termcolor | |
# Run CLIP on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Single source of truth for the CLIP checkpoint name, so the model that is
# loaded and the model name that is reported below cannot drift apart.
CLIP_MODEL_NAME = "ViT-B/32"
model, preprocess = clip.load(CLIP_MODEL_NAME, device=device)
# The model is only used for inference in this script, so switch it to eval mode.
model.eval()
# Report the chosen device and model, colourised with termcolor.
print(f"Using device: {termcolor.colored(device, 'green')}, model: {termcolor.colored(CLIP_MODEL_NAME, 'green')}")
# Directory holding one reference (template) image per supported character.
template_dir = "character_template"
# Template file stem -> the character that template depicts. Dict insertion
# order defines the class-index order used by char_list and the batch below.
_TEMPLATE_STEMS = {
    "e": "鄂", "gui": "桂",
    "hei": "黑", "ji": "冀",
    "gui1": "贵", "jing": "京",
    "lu": "鲁", "min": "闽",
    "su": "苏", "wan": "皖",
    "yu": "豫", "yue": "粤",
    "xin": "新",
}
# Template image path -> character. Paths are built from template_dir so the
# directory is named in exactly one place ("<template_dir>/<stem>.png").
char_info = {f"{template_dir}/{stem}.png": char for stem, char in _TEMPLATE_STEMS.items()}
# Class labels: index i corresponds to row i of character_tensor_list.
char_list = list(char_info.values())


def _load_template(path):
    """Open one template image, apply CLIP preprocessing, and close the file."""
    with Image.open(path) as img:
        return preprocess(img)


# Preprocess every template and batch them with a single stack along a new
# leading dimension, instead of growing the batch via repeated torch.cat
# (which re-copies the whole batch on every iteration).
character_tensor_list = torch.stack([_load_template(path) for path in char_info]).to(device)
print(f"Support Chinese characters: {termcolor.colored(char_list, 'blue')}")
# Lazily computed, L2-normalised CLIP embeddings of the template batch.
_char_features_cache = None


def _template_features():
    """Return the L2-normalised CLIP embeddings of all template images.

    The template batch (``character_tensor_list``) is fixed after module
    import, so it is encoded once on first use and reused by every later
    call instead of being re-encoded for each query image. If the templates
    or the model are ever replaced at runtime, reset ``_char_features_cache``
    to ``None``.
    """
    global _char_features_cache
    if _char_features_cache is None:
        with torch.no_grad():
            features = model.encode_image(character_tensor_list)
        _char_features_cache = features / features.norm(dim=1, keepdim=True)
    return _char_features_cache


def recognize_chinese_char(
    image: "Image.Image | None" = None,
    image_path: "str | None" = None,
    print_probs: bool = False,
) -> str:
    """Classify a character image as the most similar template character.

    The query image is embedded with the CLIP image encoder and compared
    against the embeddings of every template image. Cosine similarity is
    scaled by CLIP's learned logit scale, a softmax turns it into
    probabilities, and the highest-probability template wins.

    Args:
        image: PIL image to classify. Ignored when ``image_path`` is given.
        image_path: Optional path to load the image from (converted to RGB)
            instead of using ``image``.
        print_probs: If True, print the probability assigned to each
            supported character.

    Returns:
        The recognised character, taken from ``char_list``.

    Raises:
        ValueError: If neither ``image`` nor ``image_path`` is provided.
    """
    if image_path is not None:
        # Close the file as soon as the decoded RGB copy exists.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")
    elif image is None:
        raise ValueError("Either image or image_path must be provided")

    # Preprocess and add a leading batch dimension of size 1.
    query = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(query)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        char_features = _template_features()
        # Scaled cosine similarities: one row (the query) x one column per template.
        logits_per_image = model.logit_scale.exp() * image_features @ char_features.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    if print_probs:
        prob_dict = dict(zip(char_list, probs[0]))
        print(f"Label probs: {termcolor.colored(prob_dict, 'red')}")
    # Single query image, so pick the best template from its (only) row.
    return char_list[int(probs[0].argmax())]
if __name__ == "__main__":
    # Recognise every cut_plate/left_* crop (jpg and png). glob() returns
    # paths in arbitrary order, so sort them for deterministic output.
    image_list = sorted(glob("cut_plate/left_*.jpg") + glob("cut_plate/left_*.png"))
    for image_path in image_list:
        print(image_path, recognize_chinese_char(None, image_path))