from openai import OpenAI
import gradio as gr
import requests
from PIL import Image
import numpy as np
import ipadic
import MeCab
import difflib
import io
import os

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def generate_image(text):
    """Generate (and cache on disk) a 512x512 DALL-E 3 image for *text*.

    The image is stored at ``./{text}.png``; if that file already exists
    the API call is skipped entirely.

    Returns:
        str: local path of the (possibly cached) PNG file.
    """
    image_path = f"./{text}.png"
    if not os.path.exists(image_path):
        response = client.images.generate(
            model="dall-e-3",
            prompt=text,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        # Download with a timeout and fail loudly on HTTP errors instead of
        # handing corrupt bytes to PIL.
        http_response = requests.get(image_url, timeout=60)
        http_response.raise_for_status()
        img = Image.open(io.BytesIO(http_response.content))
        img = img.resize((512, 512))
        img.save(image_path)
    return image_path


def cos_sim(v1, v2):
    """Return the cosine similarity of two equal-length vectors."""
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))


def calculate_similarity_score(ori_text, text):
    """Score the guess *text* against the answer *ori_text* on a 0-100 scale.

    An exact string match scores 100.  Anything else is scored by embedding
    cosine similarity and capped at 99, so only the exact answer can win.

    Returns:
        int: similarity score in [0, 100].
    """
    if ori_text != text:
        response = client.embeddings.create(
            input=[ori_text, text], model="text-embedding-3-small"
        )
        score = cos_sim(response.data[0].embedding, response.data[1].embedding)
        # BUG FIX: the previous `int(round(score, 2) * 100)` truncated after a
        # float multiply (e.g. 0.29 * 100 == 28.999... -> int -> 28), silently
        # docking a point.  Scale first, then round.
        score = int(round(score * 100))
        # Reserve a perfect 100 for exact matches only.
        score = 99 if score == 100 else score
    else:
        score = 100
    return score


def tokenize_text(text):
    """Split Japanese *text* into surface-form tokens using MeCab/ipadic."""
    mecab = MeCab.Tagger(f"-Ochasen {ipadic.MECAB_ARGS}")
    # The final line of MeCab's output is the "EOS" marker; drop it.
    return [t.split()[0] for t in mecab.parse(text).splitlines()[:-1]]


def create_match_words(ori_text, text):
    """Return the tokens of the guess *text* that also occur in *ori_text*."""
    ori_words = tokenize_text(ori_text)
    words = tokenize_text(text)
    return [w for w in words if w in ori_words]


def create_hint_text(ori_text, text):
    """Build a character-level hint string from the guess to the answer.

    Characters the guess is missing (present only in *ori_text*) are shown
    as ``X``; extra characters present only in the guess are dropped;
    matching characters appear as themselves.
    """
    diff = list(difflib.ndiff(list(text), list(ori_text)))
    output = ""
    for d in diff:
        if d[:2] == "- ":
            # Character only in the guess: omit it from the hint.
            continue
        elif d[:2] == "+ ":
            # Character missing from the guess: mask it.
            output += "X"
        else:
            output += d.strip()
    return output


def update_question(option):
    """Return the cached answer-image path for the selected question key."""
    answer = os.getenv(option)
    return f"./{answer}.png"


def main(text, option):
    """Handle one submission: generate the image, score it, craft a hint.

    Args:
        text: the player's guessed prompt.
        option: question key ("Q1"/"Q2"/"Q3") whose answer prompt is read
            from the environment.

    Returns:
        tuple[str, str, str]: (generated image path, score label, hint text).
    """
    ori_text = os.getenv(option)
    image_path = generate_image(text)
    score = calculate_similarity_score(ori_text, text)
    if score < 80:
        # Low score: hint with the words that already match.
        match_words = create_match_words(ori_text, text)
        hint_text = "一致している単語リスト: " + " ".join(match_words)
    elif 80 <= score < 100:
        # Close: show which characters are still missing.
        hint_text = "一致していない箇所: " + create_hint_text(ori_text, text)
    else:
        hint_text = ""
    return image_path, f"{score}点", hint_text


def auth(user_name, password):
    """Gradio auth callback: check credentials against environment vars."""
    return user_name == os.getenv("USER_NAME") and password == os.getenv("PASSWORD")


# Pre-generate the answer image for every question so the first page load
# does not block on DALL-E.
questions = ["Q1", "Q2", "Q3"]
for q in questions:
    generate_image(os.getenv(q))

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "# プロンプトを当てるゲーム \n これは表示されている画像のプロンプトを当てるゲームです。プロンプトを入力するとそれに対応した画像とスコアとヒントが表示されます。スコア100点を目指して頑張ってください! \n\nヒントは80点未満の場合は当たっている単語(順番は合っているとは限らない)、80点以上の場合は足りない文字を「X」で示した文字列を表示しています。",
            )
            option = gr.components.Radio(
                ["Q1", "Q2", "Q3"], label="問題を選んでください!"
            )
            output_title_image = gr.components.Image(type="filepath", label="お題")
            option.change(
                update_question, inputs=[option], outputs=[output_title_image]
            )
            input_text = gr.components.Textbox(
                lines=1, label="画像にマッチするテキストを入力して!"
            )
            submit_button = gr.Button("Submit!")
        with gr.Column():
            output_image = gr.components.Image(type="filepath", label="生成画像")
            output_score = gr.components.Textbox(lines=1, label="スコア")
            output_hint_text = gr.components.Textbox(lines=1, label="ヒント")
    submit_button.click(
        main,
        inputs=[input_text, option],
        outputs=[output_image, output_score, output_hint_text],
    )

demo.launch(auth=auth)