waifu_aesthetics / predictor.py
p1atdev's picture
feat: add support of cafe_aesthetic
38739b3
raw
history blame
804 Bytes
from PIL import Image
from transformers import BeitImageProcessor, BeitForImageClassification
class Predictor:
    """Classify images with a BEiT image-classification checkpoint.

    The image processor and the classification model are both loaded from
    the same ``model_id`` via ``from_pretrained``.
    """

    def __init__(self, model_id: str) -> None:
        """Load the BEiT image processor and classification model.

        Args:
            model_id: Identifier passed to both ``BeitImageProcessor.from_pretrained``
                and ``BeitForImageClassification.from_pretrained``.
        """
        self.processor = BeitImageProcessor.from_pretrained(model_id)
        self.model = BeitForImageClassification.from_pretrained(model_id)

    def predict(self, images: list[Image.Image]) -> list[dict[str, float]]:
        """Return per-label softmax probabilities for each image.

        Args:
            images: PIL images to classify; they are processed as one batch.

        Returns:
            One dict per input image, in input order, mapping every label name
            in ``self.model.config.id2label`` to its softmax probability.
            An empty ``images`` list returns ``[]`` without running the model.
        """
        if not images:
            return []

        # Local import: return_tensors="pt" below already requires torch;
        # it is only needed here for no_grad.
        import torch

        inputs = self.processor(images, return_tensors="pt")
        # Inference only: disable autograd so no computation graph is built.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Softmax is not needed for ranking, but probabilities are easier to
        # read. A single tolist() replaces a per-element .item() loop.
        probs = logits.softmax(dim=1).tolist()

        id2label = self.model.config.id2label
        return [
            {id2label[i]: score for i, score in enumerate(row)}
            for row in probs
        ]