Spaces:
Runtime error
Runtime error
File size: 804 Bytes
38739b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
from PIL import Image
from transformers import BeitForImageClassification, BeitImageProcessor
class Predictor:
    """Image classifier wrapping a BEiT checkpoint from Hugging Face ``transformers``.

    The processor and classification model are loaded once in ``__init__`` and
    reused for every ``predict`` call.
    """

    def __init__(self, model_id: str) -> None:
        """Load the image processor and classification model.

        Args:
            model_id: Checkpoint name or path passed to ``from_pretrained`` for
                both the processor and the model.
        """
        self.processor = BeitImageProcessor.from_pretrained(model_id)
        self.model = BeitForImageClassification.from_pretrained(model_id)
        # Explicitly select inference behaviour (layers such as dropout);
        # harmless if the model is already in eval mode.
        self.model.eval()

    def predict(
        self,
        images: "list[Image.Image]",  # quoted: not evaluated at class-definition time
    ) -> list[dict[str, float]]:
        """Classify a batch of images.

        Args:
            images: Images to classify; passed as-is to the processor.

        Returns:
            One dict per input image, in input order, mapping every label in
            ``self.model.config.id2label`` to its softmax score. An empty
            list/tuple yields ``[]`` without invoking the processor or model.
        """
        if isinstance(images, (list, tuple)) and not images:
            return []

        inputs = self.processor(images, return_tensors="pt")
        # Pure inference: disable autograd so no graph is built and no
        # activations are retained for a backward pass that never happens.
        with torch.inference_mode():
            logits = self.model(**inputs).logits
            # Softmax does not change the ranking; it is applied so the scores
            # read as probabilities and look nicer in the output.
            probs = logits.softmax(dim=-1)

        id2label = self.model.config.id2label
        # One bulk tensor->Python conversion instead of an .item() per score.
        return [
            {id2label[i]: score for i, score in enumerate(row)}
            for row in probs.tolist()
        ]
|