Spaces:
Runtime error
Runtime error
import torch
from PIL import Image
from transformers import BeitForImageClassification, BeitImageProcessor
class Predictor:
    """Image classifier wrapping a Hugging Face BEiT checkpoint.

    The image processor and the classification model are both loaded from
    the same ``model_id`` via ``from_pretrained``.
    """

    def __init__(self, model_id: str) -> None:
        """Load the image processor and classification model for ``model_id``.

        Args:
            model_id: Checkpoint identifier or local path accepted by
                ``from_pretrained``.
        """
        self.processor = BeitImageProcessor.from_pretrained(model_id)
        self.model = BeitForImageClassification.from_pretrained(model_id)

    def predict(self, images: list[Image.Image]) -> list[dict[str, float]]:
        """Classify a batch of images.

        Args:
            images: Images to classify; they are preprocessed and run through
                the model as a single batch.

        Returns:
            One dict per input image, in input order, mapping every label in
            ``self.model.config.id2label`` to its softmax probability.
            An empty ``images`` list returns ``[]`` without calling the
            processor or the model.
        """
        if not images:
            # Nothing to classify; don't hand an empty batch to the processor.
            return []
        inputs = self.processor(images, return_tensors="pt")
        # Inference only: skip building the autograd graph to save memory/time.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax does not change the ranking of the logits; it is applied so
        # the scores read as probabilities (i.e. they simply look nicer).
        # Dim 1 is the label axis: each row's position i maps to id2label[i].
        probs = logits.softmax(dim=1).tolist()
        id2label = self.model.config.id2label
        return [
            {id2label[i]: score for i, score in enumerate(row)}
            for row in probs
        ]