p1atdev's picture
chore: add descriptions
4cb2217
raw
history blame contribute delete
No virus
1.8 kB
import gradio as gr
from PIL import Image
from predictor import Predictor
# Model repository IDs (the same repos linked in the UI header).
MODEL_SALT_BESTIMAGE = "saltacc/beit-bestimage-salt"
MODEL_CAFE_AESTHETIC = "cafeai/cafe_aesthetic"

# Ordered list of models to run; predict() iterates over it in this order.
MODEL_NAMES = [MODEL_SALT_BESTIMAGE, MODEL_CAFE_AESTHETIC]

# One Predictor per model ID, constructed eagerly when the module is imported
# (in MODEL_NAMES order). Whatever Predictor loads happens here.
# TODO(review): confirm in predictor.py whether this downloads weights.
models = {model_id: Predictor(model_id) for model_id in MODEL_NAMES}
def predict(image: Image.Image) -> list[dict[str, float]]:
    """Score *image* with every model listed in ``MODEL_NAMES``.

    Args:
        image: The uploaded image as a PIL image. Gradio passes ``None`` when
            Submit is clicked before an image has been uploaded.

    Returns:
        One entry per model, in ``MODEL_NAMES`` order. Each entry is element 0
        of that model's ``Predictor.predict`` result for a one-image batch
        (per the annotation, presumably a label -> score mapping; confirm
        against ``predictor.py``). If *image* is ``None``, every entry is
        ``None``. This clears the output labels instead of calling the models
        with no image.
    """
    if image is None:
        # Nothing to score: do not hand None to the predictors.
        return [None for _ in MODEL_NAMES]
    # Each predictor takes a batch (list) of images; we send a batch of one
    # and unwrap the single result.
    return [models[model_name].predict([image])[0] for model_name in MODEL_NAMES]
def construct_ui():
    """Build the Gradio Blocks app that scores an image with every model.

    Layout:
        - A Markdown header linking the original model repos.
        - A row with two columns. The left column holds the image input and
          the Submit button. The right column holds one ``gr.Label`` per
          entry of ``MODEL_NAMES``, in that order.

    Clicking Submit calls ``predict``. It returns one result per entry of
    ``MODEL_NAMES``, so the outputs list is built from the same list.

    Returns:
        The un-launched ``gr.Blocks`` instance. Call ``.launch()`` to serve it.
    """
    with gr.Blocks() as ui:
        gr.Markdown(
            """
# Waifu Aesthetics
Original model repos:
- https://huggingface.co/saltacc/beit-bestimage-salt
- https://huggingface.co/cafeai/cafe_aesthetic
"""
        )
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # NOTE(review): `source=` is the Gradio 3.x keyword.
                    # Gradio 4 replaced it with `sources=["upload"]`.
                    # Change it together with the pinned gradio version.
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                        source="upload",
                        interactive=True,
                    )
                    submit_btn = gr.Button(
                        value="Submit",
                        variant="primary",
                    )
                with gr.Column():
                    # Build one output label per model from MODEL_NAMES.
                    # The outputs then always line up, in count and order,
                    # with the list predict() returns.
                    result_labels = [
                        gr.Label(label=model_name) for model_name in MODEL_NAMES
                    ]
        submit_btn.click(predict, inputs=[input_image], outputs=result_labels)
    return ui
if __name__ == "__main__":
    # When run as a script, build the interface and serve it with Blocks.launch().
    app = construct_ui()
    app.launch()