# FoodUDT-1B / app.py
# Last updated by tharunkr24 in commit 3fb7896 ("Update app.py"); 5.39 kB.
import os
import tempfile
import time

import gradio as gr
import pandas as pd
from thirdai import bolt, licensing
licensing.set_path("license.serialized")

# Maximum number of result cards the UI builds and ``search`` fills.
max_posts = 5

df = pd.read_csv("processed_recipes_3.csv")
model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords.bolt")

# Map each recipe id (the CSV's first column) to its row position in ``df``,
# so ids coming out of ``model.predict`` can be turned back into rows. If an
# id appears more than once, the last row wins. Building this from a single
# pass over the column avoids one ``df.iloc`` call per row.
recipe_id_to_row_num = {
    recipe_id: row_num for row_num, recipe_id in enumerate(df.iloc[:, 0])
}
# Labels for the buttons on each result card.
LIKE_TEXT = "\U0001F44D"  # thumbs-up emoji
FEEDBACK_RECEIVED_TEXT = "Model updated \U0001F44C"  # OK-hand emoji
SHOW_MORE = "Show more"
SHOW_LESS = "Show less"
def retrain(query, doc_id, max_epochs=100):
    """Fine-tune ``model`` so that ``query`` is predicted as ``doc_id``.

    A one-row training CSV is written to a temporary file. Its ``Name``
    column holds the normalised query and its ``RecipeId`` column holds
    ``str(doc_id)``. The model is trained on it one epoch at a time until
    its predicted class for the query equals ``doc_id``, or until
    ``max_epochs`` passes have run. The temporary file is always removed,
    even if training raises.

    Args:
        query: Raw query text. Newlines become spaces and the text is
            lower-cased, because ``search`` lower-cases queries before calling
            ``model.predict``. Training on that same form makes the feedback
            apply to the text ``search`` actually sends.
        doc_id: Id of the recipe the user marked as relevant.
        max_epochs: Cap on training passes, so a query whose prediction never
            converges to ``doc_id`` cannot block the request indefinitely.

    Returns:
        True if the model's predicted class for the query equals ``doc_id``
        after training, False if the cap was reached first.
    """
    text = query.replace('\n', ' ').lower()
    # mkstemp yields a unique path. Close the OS handle right away because
    # pandas and the model reopen the file by name.
    fd, filename = tempfile.mkstemp(prefix="temptrain", suffix=".csv")
    os.close(fd)
    try:
        # index=False writes only the Name/RecipeId columns, without an extra
        # unnamed index column. (Presumably that column is not part of the
        # model's schema; confirm.)
        pd.DataFrame({"Name": [text], "RecipeId": [str(doc_id)]}).to_csv(
            filename, index=False
        )
        for _ in range(max_epochs):
            model.train(filename, epochs=1)
            prediction = model.predict(
                {"Name": text}, return_predicted_class=True
            )
            if prediction == doc_id:
                return True
        return False
    finally:
        os.remove(filename)
def _post_body_html(post):
    """Return the HTML for a result card's collapsible detail section.

    ``post`` is a row of ``df``. Its ``Description``,
    ``RecipeIngredientParts`` and ``RecipeInstructions`` fields are
    interpolated as-is, without HTML escaping.
    """
    return (
        f"<br/>"
        f"<h2>Description:</h2>\n{post['Description']}\n\n"
        "<hr class='solid'>"
        f"<h2>Ingredients:</h2>\n{post['RecipeIngredientParts']}\n\n"
        "<br/>"
        f"<h2>Instructions:</h2>\n{post['RecipeInstructions']}\n\n"
        "<br/>"
    )


def search(query):
    """Score ``query`` with the model and build the updates for every result slot.

    Processing steps:

    1. Normalise the query: newlines become spaces (as ``retrain`` does) and
       the text is lower-cased.
    2. Score it with ``model.predict``. The index of each score is treated as
       a recipe id.
    3. Scan the ``2 * max_posts`` highest-scoring ids in descending order.
       Ids missing from ``recipe_id_to_row_num`` are skipped, and the first
       ``max_posts`` present ids are displayed.

    Returns:
        A flat list with exactly one entry per UI output, in this order:

        * the header update;
        * ``max_posts`` updates each for boxes, titles, toggle buttons, like
          buttons and bodies;
        * the list of displayed recipe ids, where entry ``i`` is the recipe
          shown in slot ``i``.

        Slots without a result are hidden explicitly. This keeps the output
        count constant and clears cards left over from a previous search.
    """
    scores = model.predict({"Name": query.replace('\n', ' ').lower()})
    num_candidates = min(2 * max_posts, len(scores))
    # Reverse-then-slice rather than ``argsort()[-k:]``: with k == 0 the
    # latter is ``[-0:]``, which would select the whole array.
    candidate_ids = scores.argsort()[::-1][:num_candidates]

    shown_ids = []
    relevant_posts = []
    for pid in candidate_ids:
        if pid not in recipe_id_to_row_num:
            continue
        shown_ids.append(pid)
        relevant_posts.append(df.iloc[recipe_id_to_row_num[pid]])
        if len(relevant_posts) == max_posts:
            break

    num_empty = max_posts - len(relevant_posts)

    header = [gr.Markdown.update(visible=True)]
    boxes = (
        [gr.Box.update(visible=True) for _ in relevant_posts]
        + [gr.Box.update(visible=False) for _ in range(num_empty)]
    )
    titles = (
        [gr.Markdown.update(f"## {post['Name']}") for post in relevant_posts]
        + [gr.Markdown.update("") for _ in range(num_empty)]
    )
    # Every toggle/like button is reset. Those in empty slots sit inside a
    # hidden box, so resetting them is harmless.
    toggles = [
        gr.Button.update(visible=True, value=SHOW_MORE, interactive=True)
        for _ in range(max_posts)
    ]
    matches = [
        gr.Button.update(value=LIKE_TEXT, interactive=True)
        for _ in range(max_posts)
    ]
    bodies = (
        [
            gr.HTML.update(visible=False, value=_post_body_html(post))
            for post in relevant_posts
        ]
        + [gr.HTML.update(visible=False, value="") for _ in range(num_empty)]
    )
    # The like handler for slot i reads entry i of this list, so it must hold
    # the displayed ids rather than the raw candidate ids.
    return header + boxes + titles + toggles + matches + bodies + [shown_ids]
def handle_toggle(toggle):
    """Flip a result card between its collapsed and expanded states.

    Args:
        toggle: Current label of the card's toggle button (``SHOW_MORE`` when
            collapsed, ``SHOW_LESS`` when expanded).

    Returns:
        ``[button_update, body_update]``:

        * ``button_update`` relabels the button for the next click;
        * ``body_update`` shows or hides the card body.
    """
    # Any label other than SHOW_LESS counts as "collapsed", so an unexpected
    # value still yields a valid update instead of leaving locals unbound.
    expand = toggle != SHOW_LESS
    return [
        gr.Button.update(value=SHOW_LESS if expand else SHOW_MORE),
        gr.HTML.update(visible=expand),
    ]
def handle_feedback(button_id: int):
    """Create the click handler for the like button in result slot ``button_id``.

    The returned callable takes ``(doc_ids, query)``, matching the inputs
    wired to the button in the UI. It retrains the model so that ``query``
    predicts ``doc_ids[button_id]``. It then returns a button update that
    shows ``FEEDBACK_RECEIVED_TEXT`` and disables the button.
    """
    def _on_like(doc_ids, query):
        chosen_id = doc_ids[button_id]
        retrain(query=query, doc_id=chosen_id)
        return gr.Button.update(value=FEEDBACK_RECEIVED_TEXT, interactive=False)

    return _on_like
# Example query pre-filled in the search box.
default_query = "\nbaby food\n"
# Build the Gradio UI. It consists of a query box, a Search button, and a
# fixed pool of ``max_posts`` result cards. The cards start hidden, and
# ``search`` fills them in and shows them.
with gr.Blocks() as demo:
    query = gr.Textbox(value=default_query, label="Query", lines=10)
    submit = gr.Button(value="Search")
    # Hidden until ``search`` returns an update making it visible.
    header = [gr.Markdown("# Relevant Recipes", visible=False)]
    # One component per result slot in each list; index i is slot i.
    post_boxes = []
    post_titles = []
    toggle_buttons = []
    match_buttons = []
    post_bodies = []
    # Holds the id list that ``search`` returns as its last output. The like
    # button in slot i uses entry i of it.
    # NOTE(review): entry i must correspond to the recipe displayed in slot i.
    post_ids = gr.State([])
    for i in range(max_posts):
        # Card container for slot i, hidden until a search reveals it.
        with gr.Box(visible=False) as box:
            post_boxes.append(box)
            with gr.Row():
                with gr.Column(scale=5):
                    title = gr.Markdown("")
                    post_titles.append(title)
                with gr.Column(scale=1, min_width=370):
                    with gr.Row():
                        with gr.Column(scale=3, min_width=170):
                            # Expands/collapses this card's body (handle_toggle).
                            toggle = gr.Button(SHOW_MORE)
                            toggle_buttons.append(toggle)
                        with gr.Column(scale=1, min_width=170):
                            match = gr.Button(LIKE_TEXT)
                            # handle_feedback(i) binds the slot index now, so each
                            # button retrains on its own slot's id.
                            # NOTE(review): ``query`` is read at click time. If the
                            # user edited the box after searching, feedback pairs
                            # the edited text with the earlier result.
                            match.click(
                                fn=handle_feedback(button_id=i),
                                inputs=[post_ids, query],
                                outputs=[match],
                            )
                            match_buttons.append(match)
            # Recipe details; ``search`` fills this and keeps it hidden
            # until the toggle is clicked.
            body = gr.HTML("")
            post_bodies.append(body)
            toggle.click(
                fn=handle_toggle,
                inputs=[toggle],
                outputs=[toggle, body],
            )
    # Output order must match the list ``search`` returns:
    # header, boxes, titles, toggles, like buttons, bodies, then the id state.
    allblocks = (
        header +
        post_boxes +
        post_titles +
        toggle_buttons +
        match_buttons +
        post_bodies +
        [post_ids]
    )
    # Submitting the textbox and clicking Search both run the same search.
    query.submit(
        fn=search,
        inputs=[query],
        outputs=allblocks)
    submit.click(
        fn=search,
        inputs=[query],
        outputs=allblocks)
demo.launch()