Tharun Medini commited on
Commit
3443b2c
1 Parent(s): f1d466a

added all files

Browse files
Files changed (3) hide show
  1. app.py +214 -0
  2. license.serialized +0 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from thirdai import bolt, licensing
4
+ import os
5
+ import time
6
+
7
+
8
+ licensing.set_path("license.serialized")
9
+ max_posts = 5
10
+ df = pd.read_csv("processed_recipes_3.csv")
11
+ model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords.bolt")
12
+
13
+ recipe_id_to_row_num = {}
14
+
15
+ for i in range(data.shape[0]):
16
+ recipe_id_to_row_num[data.iloc[i,0]] = i
17
+
18
+
19
+ LIKE_TEXT = "👍"
20
+ FEEDBACK_RECEIVED_TEXT = "Model updated 👌"
21
+ SHOW_MORE = "Show more"
22
+ SHOW_LESS = "Show less"
23
+
24
+
25
+ def retrain(query, doc_id):
26
+ df = pd.DataFrame({
27
+ "Name": [query.replace('\n', ' ')],
28
+ "RecipeId": [str(doc_id)]
29
+ })
30
+
31
+ filename = f"temptrain{hash(query)}{hash(doc_id)}{time.time()}.csv"
32
+
33
+ df.to_csv(filename)
34
+
35
+ prediction = None
36
+
37
+ while prediction != doc_id:
38
+ model.train(filename, epochs=1)
39
+ prediction = model.predict(
40
+ {"Name": query.replace('\n', ' ')},
41
+ return_predicted_class=True)
42
+
43
+ os.remove(filename)
44
+
45
+ # sample = {"query": query.replace('\n', ' '), "id": str(doc_id)}
46
+ # batch = [sample]
47
+
48
+ # prediction = None
49
+
50
+ # while prediction != doc_id:
51
+ # model.train_batch(batch, metrics=["categorical_accuracy"])
52
+ # prediction = model.predict(sample, return_predicted_class=True)
53
+
54
+
55
+ def search(query):
56
+ scores = model.predict({"Name": query.lower()})
57
+ K = min(2*max_posts, len(scores) - 1)
58
+ sorted_post_ids = scores.argsort()[-K:][::-1]
59
+ count = 0
60
+ relevant_posts = []
61
+ for pid in sorted_post_ids:
62
+ if pid in recipe_id_to_row_num:
63
+ relevant_posts.append(df.iloc[recipe_id_to_row_num[pid]])
64
+ ##
65
+ count += 1
66
+ if count==max_posts:
67
+ break
68
+ ##
69
+ header = [gr.Markdown.update(visible=True)]
70
+ boxes = [
71
+ gr.Box.update(visible=True)
72
+ for _ in relevant_posts
73
+ ]
74
+ titles = [
75
+ gr.Markdown.update(f"## {post['Name']}")
76
+ for post in relevant_posts
77
+ ]
78
+ toggles = [
79
+ gr.Button.update(
80
+ visible=True,
81
+ value=SHOW_MORE,
82
+ interactive=True,
83
+ )
84
+ for _ in relevant_posts
85
+ ]
86
+ matches = [
87
+ gr.Button.update(
88
+ value=LIKE_TEXT,
89
+ interactive=True,
90
+ )
91
+ for _ in relevant_posts
92
+ ]
93
+ bodies = [
94
+ gr.HTML.update(
95
+ visible=False,
96
+ value=f"<br/>"
97
+ f"<h2>Description:</h2>\n{post['Description']}\n\n"
98
+ "<hr class='solid'>"
99
+ f"<h2>Ingredients:</h2>\n{post['RecipeIngredientParts']}\n\n"
100
+ "<br/>"
101
+ f"<h2>Instructions:</h2>\n{post['RecipeInstructions']}\n\n"
102
+ "<br/>")
103
+ for post in relevant_posts
104
+ ]
105
+
106
+ return (
107
+ header +
108
+ boxes +
109
+ titles +
110
+ toggles +
111
+ matches +
112
+ bodies +
113
+ [sorted_post_ids]
114
+ )
115
+
116
+
117
+ def handle_toggle(toggle):
118
+ if toggle == SHOW_MORE:
119
+ new_toggle_text = SHOW_LESS
120
+ visible = True
121
+ if toggle == SHOW_LESS:
122
+ new_toggle_text = SHOW_MORE
123
+ visible = False
124
+ return [
125
+ gr.Button.update(new_toggle_text),
126
+ gr.HTML.update(visible=visible),
127
+ ]
128
+
129
+
130
+ def handle_feedback(button_id: int):
131
+ def register_feedback(doc_ids, query):
132
+ retrain(
133
+ query=query,
134
+ doc_id=doc_ids[button_id]
135
+ )
136
+ return gr.Button.update(
137
+ value=FEEDBACK_RECEIVED_TEXT,
138
+ interactive=False,
139
+ )
140
+
141
+ return register_feedback
142
+
143
+
144
+ default_query = (
145
+ """
146
+ baby food
147
+ """
148
+ )
149
+
150
+
151
+ with gr.Blocks() as demo:
152
+ query = gr.Textbox(value=default_query, label="Question", lines=10)
153
+ submit = gr.Button(value="Search")
154
+
155
+ header = [gr.Markdown("# Relevant Recipes", visible=False)]
156
+ post_boxes = []
157
+ post_titles = []
158
+ toggle_buttons = []
159
+ match_buttons = []
160
+ post_bodies = []
161
+ post_ids = gr.State([])
162
+
163
+ for i in range(max_posts):
164
+ with gr.Box(visible=False) as box:
165
+ post_boxes.append(box)
166
+
167
+ with gr.Row():
168
+ with gr.Column(scale=5):
169
+ title = gr.Markdown("")
170
+ post_titles.append(title)
171
+ with gr.Column(scale=1, min_width=370):
172
+ with gr.Row():
173
+ with gr.Column(scale=3, min_width=170):
174
+ toggle = gr.Button(SHOW_MORE)
175
+ toggle_buttons.append(toggle)
176
+ with gr.Column(scale=1, min_width=170):
177
+ match = gr.Button(LIKE_TEXT)
178
+ match.click(
179
+ fn=handle_feedback(button_id=i),
180
+ inputs=[post_ids, query],
181
+ outputs=[match],
182
+ )
183
+ match_buttons.append(match)
184
+
185
+ body = gr.HTML("")
186
+ post_bodies.append(body)
187
+
188
+ toggle.click(
189
+ fn=handle_toggle,
190
+ inputs=[toggle],
191
+ outputs=[toggle, body],
192
+ )
193
+
194
+ allblocks = (
195
+ header +
196
+ post_boxes +
197
+ post_titles +
198
+ toggle_buttons +
199
+ match_buttons +
200
+ post_bodies +
201
+ [post_ids]
202
+ )
203
+
204
+ query.submit(
205
+ fn=search,
206
+ inputs=[query],
207
+ outputs=allblocks)
208
+ submit.click(
209
+ fn=search,
210
+ inputs=[query],
211
+ outputs=allblocks)
212
+
213
+
214
+ demo.launch()
license.serialized ADDED
Binary file (416 Bytes). View file
 
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ thirdai==0.6.0
2
+ pandas
3
+ gradio