File size: 1,798 Bytes
c50a0a6
 
38739b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4cb2217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38739b3
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Optional

import gradio as gr
from PIL import Image

from predictor import Predictor

# Hugging Face repo ids of the aesthetic classifiers exposed by this app.
MODEL_SALT_BESTIMAGE = "saltacc/beit-bestimage-salt"
MODEL_CAFE_AESTHETIC = "cafeai/cafe_aesthetic"

# Canonical model order: predict() returns one result per entry, in this order.
MODEL_NAMES = [
    MODEL_SALT_BESTIMAGE,
    MODEL_CAFE_AESTHETIC,
]

# One Predictor per model, constructed once at import time and keyed by repo id.
models = {repo_id: Predictor(repo_id) for repo_id in MODEL_NAMES}


def predict(image: Optional[Image.Image]) -> list[Optional[dict[str, float]]]:
    """Score one image with every model listed in MODEL_NAMES.

    Args:
        image: The uploaded PIL image. This may be None when the user clicks
            Submit without uploading anything, because Gradio supplies None
            for an empty image component.

    Returns:
        One entry per model, in MODEL_NAMES order. Each entry is the first
        (and only) element of ``Predictor.predict([image])``. This is
        presumably a label -> score mapping for ``gr.Label``; verify against
        predictor.py. When ``image`` is None, every entry is None, so the
        output Labels show nothing instead of the app raising inside a model.
    """
    if image is None:
        # Nothing to score; do not hand None to the models.
        return [None] * len(MODEL_NAMES)

    # Predictor.predict takes a batch: wrap the single image, unwrap the
    # single result.
    return [models[model_name].predict([image])[0] for model_name in MODEL_NAMES]


def construct_ui():
    """Build the Gradio Blocks app: one image input and one Label per model.

    The output Label components and the repo links are generated from
    MODEL_NAMES. predict() also returns its results in MODEL_NAMES order.
    As a result, the UI needs no edits when models are added or reordered.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    # Build the description from unindented lines so the Markdown does not
    # depend on indentation being stripped from a triple-quoted literal.
    repo_links = [f"- https://huggingface.co/{model_name}" for model_name in MODEL_NAMES]
    description = "\n".join(
        ["# Waifu Aesthetics", "", "Original model repos:", "", *repo_links]
    )

    with gr.Blocks() as ui:
        gr.Markdown(description)

        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                        source="upload",
                        interactive=True,
                    )

                    submit_btn = gr.Button(
                        value="Submit",
                        variant="primary",
                    )

                with gr.Column():
                    # One output per model, in MODEL_NAMES order. This is the
                    # same order predict() returns its results in.
                    result_labels = [
                        gr.Label(label=model_name) for model_name in MODEL_NAMES
                    ]

        submit_btn.click(predict, inputs=[input_image], outputs=result_labels)

    return ui


if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start the Gradio server.
    app = construct_ui()
    app.launch()