Kosuke-Yamada commited on
Commit
804a590
1 Parent(s): 1a0adff

change file

Browse files
Files changed (1) hide show
  1. app.py +130 -3
app.py CHANGED
@@ -1,7 +1,134 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
7
  demo.launch()
 
1
  import gradio as gr
2
+ from google.colab import userdata
3
+ from openai import OpenAI
4
+ import gradio as gr
5
+ import requests
6
+ import os
7
+ from PIL import Image
8
+ import numpy as np
9
+ import ipadic
10
+ import MeCab
11
+ import difflib
12
+ import io
13
+ import os
14
+
15
+ api_key = os.getenvs('OPENAI_API_KEY')
16
+ client = OpenAI(api_key=api_key)
17
+
18
+ def generate_image(text):
19
+ image_path = f"/content/images/{text}.png"
20
+ if not os.path.exists(image_path):
21
+ response = client.images.generate(
22
+ model="dall-e-3",
23
+ prompt=text,
24
+ size="1024x1024",
25
+ quality="standard",
26
+ n=1,
27
+ )
28
+ image_url = response.data[0].url
29
+ image_data = requests.get(image_url).content
30
+ img = Image.open(io.BytesIO((image_data)))
31
+ img = img.resize((512, 512))
32
+ img.save(image_path)
33
+ return image_path
34
+
35
+ def calulate_similarity_score(ori_text, text):
36
+ if ori_text != text:
37
+ model_name = "text-embedding-3-small"
38
+ response = client.embeddings.create(input = [ori_text, text], model=model_name)
39
+ score = cos_sim(response.data[0].embedding, response.data[1].embedding)
40
+ score = int(round(score, 2) * 100)
41
+ if score == 100:
42
+ score = 99
43
+ else:
44
+ score = 100
45
+ return score
46
+
47
+ def cos_sim(v1, v2):
48
+ return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
49
+
50
+ def tokenize_text(text):
51
+ mecab = MeCab.Tagger(f"-Ochasen {ipadic.MECAB_ARGS}")
52
+ return [
53
+ t.split()[0]
54
+ for t in mecab.parse(text).splitlines()[:-1]
55
+ ]
56
+
57
+ def create_match_words(ori_text, text):
58
+ ori_words = tokenize_text(ori_text)
59
+ words = tokenize_text(text)
60
+ match_words = [w for w in words if w in ori_words]
61
+ return match_words
62
+
63
+ def create_hint_text(ori_text, text):
64
+ response = list(difflib.ndiff(list(text), list(ori_text)))
65
+ output = ""
66
+ for r in response:
67
+ if r[:2] == "- ":
68
+ continue
69
+ elif r[:2] == "+ ":
70
+ output += "^"
71
+ else:
72
+ output += r.strip()
73
+ return output
74
+
75
+ def update_question(selected_option):
76
+ if selected_option == "Q1":
77
+ return "/content/images/白い猫が木の上で休んでいる.png"
78
+ elif selected_option == "Q2":
79
+ return "/content/images/サメが海の中で暴れている.png"
80
+ elif selected_option == "Q3":
81
+ return "/content/images/東京スカイツリーの近くで花火大会が行われている.png"
82
+ elif selected_option == "Q4":
83
+ return "/content/images/イカとタイがいた都会.png"
84
+ elif selected_option == "Q5":
85
+ return "/content/images/赤いきつねと緑のたぬき.png"
86
+ if selected_option == "Q6":
87
+ return "/content/images/宇宙に向かってたい焼きが空を飛んでいる.png"
88
+ elif selected_option == "Q7":
89
+ return "/content/images/イケメンが海岸でクリームパンを眺めている.png"
90
+ elif selected_option == "Q8":
91
+ return "/content/images/生麦生米生卵生麦生米生卵生麦生米生卵.png"
92
+ elif selected_option == "Q9":
93
+ return "/content/images/サイバーエージェントで働く人たち.png"
94
+ elif selected_option == "Q10":
95
+ return "/content/images/柿くへば鐘が鳴るなり法隆寺.png"
96
+ elif selected_option == "Q11":
97
+ return "/content/images/鳴くよウグイス平安京.png"
98
+ else:
99
+ return "/content/images/abc.png"
100
+
101
+ def main(image, text, option):
102
+ ori_text = update_question(option).split("/")[-1].split(".png")[0]
103
+ image_path = generate_image(text)
104
+ score = calulate_similarity_score(ori_text, text)
105
+
106
+ if score < 80:
107
+ match_words = create_match_words(ori_text, text)
108
+ hint_text = "一致している単語リスト: "+" ".join(match_words)
109
+ elif 80 <= score < 100:
110
+ hint_text = "一致していない箇所: "+create_hint_text(ori_text, text)
111
+ else:
112
+ hint_text = ""
113
+ return image_path, f"{score}点", hint_text
114
+
115
+ with gr.Blocks() as demo:
116
+ with gr.Row():
117
+ with gr.Column():
118
+ gr.Markdown(
119
+ "# プロンプトを当てるゲーム \n これは表示されている画像のプロンプトを当てるゲームです。プロンプトを入力するとそれに対応した画像とスコアとヒントが表示されます。スコア100点を目指して頑張ってください! \n\nヒントは80点未満の場合は当たっている単語、80点以上の場合は足りない文字を「^」で示した文字列を表示しています。",
120
+ )
121
+ selected_option = gr.components.Radio(["Q1", "Q2", "Q3", "Q4", "Q5", "Q6", "Q7", "Q8", "Q9", "Q10", "Q11"], label="問題を選んでください!")
122
+ output_title_image = gr.components.Image(type="filepath", label="お題")
123
+ selected_option.change(update_question, inputs=[selected_option], outputs=[output_title_image])
124
 
125
+ input_text = gr.components.Textbox(lines=1, label="画像にマッチするテキストを入力して!")
126
+ submit_button = gr.Button("Submit")
127
+ with gr.Column():
128
+ output_image = gr.components.Image(type="filepath", label="生成画像")
129
+ output_score = gr.components.Textbox(lines=1, label="スコア")
130
+ output_hint_text = gr.components.Textbox(lines=1, label="ヒント")
131
 
132
+ submit_button.click(main, inputs=[output_title_image, input_text, selected_option], outputs=[output_image, output_score, output_hint_text])
133
+ demo.launch(server_port=8892)
134
  demo.launch()