"""Gradio demo: run one uploaded image through several Hugging Face models.

Each model listed in ``MODEL_NAMES`` is wrapped in a ``Predictor`` and its
result for the uploaded image is shown in its own ``gr.Label``.
"""

from typing import Optional

import gradio as gr
from PIL import Image

from predictor import Predictor

MODEL_SALT_BESTIMAGE = "saltacc/beit-bestimage-salt"
MODEL_CAFE_AESTHETIC = "cafeai/cafe_aesthetic"

# Order matters: predict() returns one result per entry in this order, and
# construct_ui() creates its output labels in the same order.
MODEL_NAMES = [
    MODEL_SALT_BESTIMAGE,
    MODEL_CAFE_AESTHETIC,
]

# One Predictor per model, constructed once when the module is imported.
models = {model_name: Predictor(model_name) for model_name in MODEL_NAMES}

# Page header. Built line by line so the Markdown structure (heading, then a
# bullet list of repo links) does not depend on source-code indentation.
_HEADER_MARKDOWN = "\n".join(
    [
        "# Waifu Aesthetics",
        "",
        "Original model repos:",
        "",
        *(f"- https://huggingface.co/{model_name}" for model_name in MODEL_NAMES),
    ]
)


def predict(image: Optional[Image.Image]) -> list[Optional[dict[str, float]]]:
    """Run ``image`` through every model in ``MODEL_NAMES``.

    Args:
        image: The uploaded picture, or ``None`` when the user pressed
            "Submit" without providing one.

    Returns:
        One entry per model, in ``MODEL_NAMES`` order. Each entry is the
        first element of that model's ``Predictor.predict([image])`` result.
        When ``image`` is ``None`` every entry is ``None`` so that the output
        labels are cleared.
    """
    if image is None:
        # Gradio passes None for an empty image input. Predictor's handling
        # of None is not visible from this file, so do not forward it.
        return [None] * len(MODEL_NAMES)

    # Each predictor is called with a one-element batch; keep its only result.
    return [models[model_name].predict([image])[0] for model_name in MODEL_NAMES]


def construct_ui() -> gr.Blocks:
    """Build the Blocks UI: image input and button, plus one label per model.

    Returns:
        The assembled ``gr.Blocks`` app, not yet launched.
    """
    # NOTE(review): the nesting below was reconstructed from a flattened
    # source (input column and result column side by side in one row) --
    # confirm against the intended layout.
    with gr.Blocks() as ui:
        gr.Markdown(_HEADER_MARKDOWN)

        with gr.Column():
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                        source="upload",
                        interactive=True,
                    )
                    submit_btn = gr.Button(
                        value="Submit",
                        variant="primary",
                    )

                with gr.Column():
                    # One label per model, in the order predict() returns.
                    result_labels = [
                        gr.Label(label=model_name) for model_name in MODEL_NAMES
                    ]

        submit_btn.click(
            predict,
            inputs=[input_image],
            outputs=result_labels,
        )

    return ui


if __name__ == "__main__":
    construct_ui().launch()