File size: 804 Bytes
38739b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from PIL import Image
from transformers import BeitImageProcessor, BeitForImageClassification


class Predictor:
    """Image classifier backed by a pretrained BEiT model.

    The image processor and the classification model are both loaded from
    the same ``model_id`` when the instance is created.
    """

    def __init__(self, model_id: str) -> None:
        """Load the processor and the classification model.

        Args:
            model_id: Identifier passed unchanged to ``from_pretrained`` for
                both the image processor and the model.
        """
        self.processor = BeitImageProcessor.from_pretrained(model_id)
        self.model = BeitForImageClassification.from_pretrained(model_id)

    def predict(self, images: list[Image.Image]) -> list[dict[str, float]]:
        """Classify a batch of images.

        Args:
            images: Images to classify, processed together as one batch.

        Returns:
            One dict per input image, in input order, mapping each label
            from ``model.config.id2label`` to its softmax score. An empty
            input yields an empty list.
        """
        # Nothing to classify; skip the processor and the forward pass.
        if not images:
            return []

        inputs = self.processor(images, return_tensors="pt")

        # Inference only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            # Softmax is applied mostly so the scores look nicer (they sum
            # to 1 per image); the ranking is the same as the raw logits.
            probabilities = self.model(**inputs).logits.softmax(dim=-1)

        id2label = self.model.config.id2label

        # tolist() converts the whole batch to Python floats in one call.
        return [
            {id2label[index]: score for index, score in enumerate(row)}
            for row in probabilities.tolist()
        ]