Spaces:
Runtime error
Runtime error
Tharun Medini
commited on
Commit
•
3443b2c
1
Parent(s):
f1d466a
added all files
Browse files- app.py +214 -0
- license.serialized +0 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from thirdai import bolt, licensing
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
|
8 |
+
licensing.set_path("license.serialized")
|
9 |
+
max_posts = 5
|
10 |
+
df = pd.read_csv("processed_recipes_3.csv")
|
11 |
+
model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords.bolt")
|
12 |
+
|
13 |
+
recipe_id_to_row_num = {}
|
14 |
+
|
15 |
+
for i in range(data.shape[0]):
|
16 |
+
recipe_id_to_row_num[data.iloc[i,0]] = i
|
17 |
+
|
18 |
+
|
19 |
+
LIKE_TEXT = "👍"
|
20 |
+
FEEDBACK_RECEIVED_TEXT = "Model updated 👌"
|
21 |
+
SHOW_MORE = "Show more"
|
22 |
+
SHOW_LESS = "Show less"
|
23 |
+
|
24 |
+
|
25 |
+
def retrain(query, doc_id):
|
26 |
+
df = pd.DataFrame({
|
27 |
+
"Name": [query.replace('\n', ' ')],
|
28 |
+
"RecipeId": [str(doc_id)]
|
29 |
+
})
|
30 |
+
|
31 |
+
filename = f"temptrain{hash(query)}{hash(doc_id)}{time.time()}.csv"
|
32 |
+
|
33 |
+
df.to_csv(filename)
|
34 |
+
|
35 |
+
prediction = None
|
36 |
+
|
37 |
+
while prediction != doc_id:
|
38 |
+
model.train(filename, epochs=1)
|
39 |
+
prediction = model.predict(
|
40 |
+
{"Name": query.replace('\n', ' ')},
|
41 |
+
return_predicted_class=True)
|
42 |
+
|
43 |
+
os.remove(filename)
|
44 |
+
|
45 |
+
# sample = {"query": query.replace('\n', ' '), "id": str(doc_id)}
|
46 |
+
# batch = [sample]
|
47 |
+
|
48 |
+
# prediction = None
|
49 |
+
|
50 |
+
# while prediction != doc_id:
|
51 |
+
# model.train_batch(batch, metrics=["categorical_accuracy"])
|
52 |
+
# prediction = model.predict(sample, return_predicted_class=True)
|
53 |
+
|
54 |
+
|
55 |
+
def search(query):
|
56 |
+
scores = model.predict({"Name": query.lower()})
|
57 |
+
K = min(2*max_posts, len(scores) - 1)
|
58 |
+
sorted_post_ids = scores.argsort()[-K:][::-1]
|
59 |
+
count = 0
|
60 |
+
relevant_posts = []
|
61 |
+
for pid in sorted_post_ids:
|
62 |
+
if pid in recipe_id_to_row_num:
|
63 |
+
relevant_posts.append(df.iloc[recipe_id_to_row_num[pid]])
|
64 |
+
##
|
65 |
+
count += 1
|
66 |
+
if count==max_posts:
|
67 |
+
break
|
68 |
+
##
|
69 |
+
header = [gr.Markdown.update(visible=True)]
|
70 |
+
boxes = [
|
71 |
+
gr.Box.update(visible=True)
|
72 |
+
for _ in relevant_posts
|
73 |
+
]
|
74 |
+
titles = [
|
75 |
+
gr.Markdown.update(f"## {post['Name']}")
|
76 |
+
for post in relevant_posts
|
77 |
+
]
|
78 |
+
toggles = [
|
79 |
+
gr.Button.update(
|
80 |
+
visible=True,
|
81 |
+
value=SHOW_MORE,
|
82 |
+
interactive=True,
|
83 |
+
)
|
84 |
+
for _ in relevant_posts
|
85 |
+
]
|
86 |
+
matches = [
|
87 |
+
gr.Button.update(
|
88 |
+
value=LIKE_TEXT,
|
89 |
+
interactive=True,
|
90 |
+
)
|
91 |
+
for _ in relevant_posts
|
92 |
+
]
|
93 |
+
bodies = [
|
94 |
+
gr.HTML.update(
|
95 |
+
visible=False,
|
96 |
+
value=f"<br/>"
|
97 |
+
f"<h2>Description:</h2>\n{post['Description']}\n\n"
|
98 |
+
"<hr class='solid'>"
|
99 |
+
f"<h2>Ingredients:</h2>\n{post['RecipeIngredientParts']}\n\n"
|
100 |
+
"<br/>"
|
101 |
+
f"<h2>Instructions:</h2>\n{post['RecipeInstructions']}\n\n"
|
102 |
+
"<br/>")
|
103 |
+
for post in relevant_posts
|
104 |
+
]
|
105 |
+
|
106 |
+
return (
|
107 |
+
header +
|
108 |
+
boxes +
|
109 |
+
titles +
|
110 |
+
toggles +
|
111 |
+
matches +
|
112 |
+
bodies +
|
113 |
+
[sorted_post_ids]
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
def handle_toggle(toggle):
|
118 |
+
if toggle == SHOW_MORE:
|
119 |
+
new_toggle_text = SHOW_LESS
|
120 |
+
visible = True
|
121 |
+
if toggle == SHOW_LESS:
|
122 |
+
new_toggle_text = SHOW_MORE
|
123 |
+
visible = False
|
124 |
+
return [
|
125 |
+
gr.Button.update(new_toggle_text),
|
126 |
+
gr.HTML.update(visible=visible),
|
127 |
+
]
|
128 |
+
|
129 |
+
|
130 |
+
def handle_feedback(button_id: int):
|
131 |
+
def register_feedback(doc_ids, query):
|
132 |
+
retrain(
|
133 |
+
query=query,
|
134 |
+
doc_id=doc_ids[button_id]
|
135 |
+
)
|
136 |
+
return gr.Button.update(
|
137 |
+
value=FEEDBACK_RECEIVED_TEXT,
|
138 |
+
interactive=False,
|
139 |
+
)
|
140 |
+
|
141 |
+
return register_feedback
|
142 |
+
|
143 |
+
|
144 |
+
default_query = (
|
145 |
+
"""
|
146 |
+
baby food
|
147 |
+
"""
|
148 |
+
)
|
149 |
+
|
150 |
+
|
151 |
+
with gr.Blocks() as demo:
|
152 |
+
query = gr.Textbox(value=default_query, label="Question", lines=10)
|
153 |
+
submit = gr.Button(value="Search")
|
154 |
+
|
155 |
+
header = [gr.Markdown("# Relevant Recipes", visible=False)]
|
156 |
+
post_boxes = []
|
157 |
+
post_titles = []
|
158 |
+
toggle_buttons = []
|
159 |
+
match_buttons = []
|
160 |
+
post_bodies = []
|
161 |
+
post_ids = gr.State([])
|
162 |
+
|
163 |
+
for i in range(max_posts):
|
164 |
+
with gr.Box(visible=False) as box:
|
165 |
+
post_boxes.append(box)
|
166 |
+
|
167 |
+
with gr.Row():
|
168 |
+
with gr.Column(scale=5):
|
169 |
+
title = gr.Markdown("")
|
170 |
+
post_titles.append(title)
|
171 |
+
with gr.Column(scale=1, min_width=370):
|
172 |
+
with gr.Row():
|
173 |
+
with gr.Column(scale=3, min_width=170):
|
174 |
+
toggle = gr.Button(SHOW_MORE)
|
175 |
+
toggle_buttons.append(toggle)
|
176 |
+
with gr.Column(scale=1, min_width=170):
|
177 |
+
match = gr.Button(LIKE_TEXT)
|
178 |
+
match.click(
|
179 |
+
fn=handle_feedback(button_id=i),
|
180 |
+
inputs=[post_ids, query],
|
181 |
+
outputs=[match],
|
182 |
+
)
|
183 |
+
match_buttons.append(match)
|
184 |
+
|
185 |
+
body = gr.HTML("")
|
186 |
+
post_bodies.append(body)
|
187 |
+
|
188 |
+
toggle.click(
|
189 |
+
fn=handle_toggle,
|
190 |
+
inputs=[toggle],
|
191 |
+
outputs=[toggle, body],
|
192 |
+
)
|
193 |
+
|
194 |
+
allblocks = (
|
195 |
+
header +
|
196 |
+
post_boxes +
|
197 |
+
post_titles +
|
198 |
+
toggle_buttons +
|
199 |
+
match_buttons +
|
200 |
+
post_bodies +
|
201 |
+
[post_ids]
|
202 |
+
)
|
203 |
+
|
204 |
+
query.submit(
|
205 |
+
fn=search,
|
206 |
+
inputs=[query],
|
207 |
+
outputs=allblocks)
|
208 |
+
submit.click(
|
209 |
+
fn=search,
|
210 |
+
inputs=[query],
|
211 |
+
outputs=allblocks)
|
212 |
+
|
213 |
+
|
214 |
+
demo.launch()
|
license.serialized
ADDED
Binary file (416 Bytes). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
thirdai==0.6.0
|
2 |
+
pandas
|
3 |
+
gradio
|