FoodUDT-1B / app.py
tharunkr24's picture
Update app.py
37dfef1
raw
history blame
6.13 kB
import gradio as gr
import pandas as pd
from thirdai import bolt, licensing
import os
import time
# Activate the ThirdAI license before any model is loaded.
licensing.set_path("license.serialized")
# Number of result cards rendered in the UI (and returned by ``search``).
max_posts = 5
# Recipe table; rows are displayed by positional index in ``search``.
df = pd.read_csv("processed_recipes_3.csv")
# Pretrained UDT checkpoint used by both ``search`` and ``retrain``.
model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords_4gram.bolt")
# Map each value in the first column of ``df`` (presumably the recipe id) to
# its positional row number. A single pass over the column replaces one
# ``df.iloc[i, 0]`` scalar lookup per row, which is very slow on a frame of
# this size. As before, a later duplicate id overwrites an earlier one.
recipe_id_to_row_num = {
    recipe_id: row_num for row_num, recipe_id in enumerate(df.iloc[:, 0])
}
# Banner shown at the top of the page.
INTRO_MARKDOWN = (
"""# A billion parameter model, trained on a single CPU, in just 90 mins, on 522K recipes from food.com !!
"""
)
# Button labels. The emoji were stored as mojibake (UTF-8 bytes decoded with
# a legacy code page, rendering as "πŸ‘"); escapes keep them encoding-proof.
LIKE_TEXT = "\U0001F44D update LLM"  # thumbs up
FEEDBACK_RECEIVED_TEXT = "\U0001F44C Click search for updated results"  # OK hand
# Toggle labels; ``handle_toggle`` compares the current label against these.
SHOW_MORE = "Show more"
SHOW_LESS = "Show less"
def retrain(query, doc_id):
    """Fine-tune the model until it predicts ``doc_id`` for ``query``.

    The query is lower-cased, newlines are turned into spaces, and the result
    is rewritten as space-separated overlapping character 4-grams. A one-row
    CSV holding that text (``Name``) and the label (``RecipeId``) is written
    to a temporary file. The model is trained on it one epoch at a time until
    its predicted class equals ``doc_id``. The text passed to
    ``model.predict`` is exactly the text that was trained on.

    Args:
        query: Raw query text from the UI.
        doc_id: Class id the model should predict for ``query``.

    Note:
        There is no epoch cap; this blocks until the prediction matches.
    """
    # Normalize newlines *before* building n-grams. Previously the result of
    # ``replace`` was discarded, so training text and prediction text differed.
    query = query.lower().replace('\n', ' ')
    # Overlapping character 4-grams, e.g. "abcde" -> "abcd bcde".
    query = ' '.join([query[i:i+4] for i in range(len(query)-3)])
    # Named ``train_df`` so it does not shadow the module-level recipe ``df``.
    train_df = pd.DataFrame({
        "Name": [query],
        "RecipeId": [str(doc_id)]
    })
    filename = f"temptrain{hash(query)}{hash(doc_id)}{time.time()}.csv"
    # NOTE(review): the DataFrame index is written as an extra unnamed
    # column; confirm the UDT ignores columns it was not configured with.
    train_df.to_csv(filename)
    try:
        prediction = None
        while prediction != doc_id:
            model.train(filename, epochs=1)
            prediction = model.predict(
                {"Name": query},
                return_predicted_class=True)
    finally:
        # Remove the temp file even if training or prediction raises.
        os.remove(filename)
def search(query):
    """Score ``query`` with the model and build updates for the result cards.

    The query is lower-cased, newlines are replaced by spaces, and it is
    rewritten as space-separated overlapping character 4-grams before being
    scored. The ``max_posts`` highest-scoring ids are looked up in ``df`` by
    position.

    Args:
        query: Raw query text from the UI.

    Returns:
        A flat list of Gradio updates in the order header, boxes, titles,
        toggle buttons, like buttons, bodies. It is followed by the array of
        selected ids, which the caller stores in the ``post_ids`` state for
        feedback.
    """
    query = query.lower()
    # Replace newlines so the model input matches the text ``retrain``
    # predicts on when checking feedback (it applies the same replacement).
    query = query.replace('\n', ' ')
    # Overlapping character 4-grams, e.g. "abcde" -> "abcd bcde".
    query = ' '.join([query[i:i+4] for i in range(len(query)-3)])
    scores = model.predict({"Name": query})
    # Positions of the ``max_posts`` highest scores, best first.
    sorted_ids = scores.argsort()[-max_posts:][::-1]
    # NOTE(review): predicted ids are used directly as row positions in
    # ``df``; ``recipe_id_to_row_num`` is not consulted. Confirm the model's
    # class ids are row numbers rather than RecipeId values.
    relevant_posts = [
        df.iloc[pid] for pid in sorted_ids
    ]
    header = [gr.Markdown.update(visible=True)]
    boxes = [
        gr.Box.update(visible=True)
        for _ in relevant_posts
    ]
    titles = [
        gr.Markdown.update(f"## {post['Name']}")
        for post in relevant_posts
    ]
    # Every new search resets toggles to collapsed and re-enables feedback.
    toggles = [
        gr.Button.update(
            visible=True,
            value=SHOW_MORE,
            interactive=True,
        )
        for _ in relevant_posts
    ]
    matches = [
        gr.Button.update(
            value=LIKE_TEXT,
            interactive=True,
        )
        for _ in relevant_posts
    ]
    # Recipe details start hidden; ``handle_toggle`` reveals them.
    bodies = [
        gr.HTML.update(
            visible=False,
            value=f"<br/>"
            f"<h2>Description:</h2>\n{post['Description']}\n\n"
            "<hr class='solid'>"
            f"<h2>Ingredients:</h2>\n{post['RecipeIngredientParts']}\n\n"
            "<br/>"
            f"<h2>Instructions:</h2>\n{post['RecipeInstructions']}\n\n"
            "<br/>")
        for post in relevant_posts
    ]
    return (
        header +
        boxes +
        titles +
        toggles +
        matches +
        bodies +
        [sorted_ids]
    )
def handle_toggle(toggle):
    """Flip a result card between collapsed and expanded.

    Args:
        toggle: Current label of the card's toggle button.

    Returns:
        ``[button_update, body_update]``. If the label is ``SHOW_MORE``, the
        card expands: the label becomes ``SHOW_LESS`` and the body is shown.
        Any other label collapses the card: the label becomes ``SHOW_MORE``
        and the body is hidden.
    """
    # A single boolean guarantees both outputs are always defined; the
    # previous pair of independent ``if``s raised UnboundLocalError for an
    # unexpected label.
    expand = toggle == SHOW_MORE
    new_toggle_text = SHOW_LESS if expand else SHOW_MORE
    return [
        gr.Button.update(new_toggle_text),
        gr.HTML.update(visible=expand),
    ]
def handle_feedback(button_id: int):
    """Create the click handler for the like button in slot ``button_id``.

    The returned callback receives the current list of displayed ids and the
    query text. It retrains the model so the query maps to the id shown in
    that slot, then returns an update that relabels and disables the button.
    """
    def register_feedback(doc_ids, query):
        chosen_id = doc_ids[button_id]
        retrain(query=query, doc_id=chosen_id)
        return gr.Button.update(
            value=FEEDBACK_RECEIVED_TEXT,
            interactive=False,
        )

    return register_feedback
# Example query pre-filled in the search box.
default_query = (
    "biryani lamb spicy contains cloves and red chili powder, "
    "made with ghee and hard boiled eggs, "
    "made by grinding coconut and cashew"
)
# Build the UI: a query box and search button, followed by ``max_posts``
# initially hidden result cards that ``search`` fills in and reveals.
with gr.Blocks() as demo:
    gr.Markdown(INTRO_MARKDOWN)
    query = gr.Textbox(value=default_query, label="Query", lines=10)
    submit = gr.Button(value="Search")
    # Section heading; hidden until a search returns results.
    header = [gr.Markdown("# Relevant Recipes", visible=False)]
    # One entry per result slot in each list. Their concatenation order below
    # must match the order of the updates returned by ``search``.
    post_boxes = []
    post_titles = []
    toggle_buttons = []
    match_buttons = []
    post_bodies = []
    # Ids of the currently displayed results: written as the last output of
    # ``search`` and read by the like-button handlers.
    post_ids = gr.State([])
    for i in range(max_posts):
        with gr.Box(visible=False) as box:
            post_boxes.append(box)
            with gr.Row():
                with gr.Column(scale=5):
                    title = gr.Markdown("")
                    post_titles.append(title)
                with gr.Column(scale=1, min_width=370):
                    with gr.Row():
                        with gr.Column(scale=3, min_width=170):
                            toggle = gr.Button(SHOW_MORE)
                            toggle_buttons.append(toggle)
                        with gr.Column(scale=1, min_width=170):
                            match = gr.Button(LIKE_TEXT)
                            # handle_feedback(button_id=i) is called here, so
                            # each handler is bound to its own slot index.
                            match.click(
                                fn=handle_feedback(button_id=i),
                                inputs=[post_ids, query],
                                outputs=[match],
                            )
                            match_buttons.append(match)
            # Recipe details; visibility is flipped by the toggle button.
            body = gr.HTML("")
            post_bodies.append(body)
            toggle.click(
                fn=handle_toggle,
                inputs=[toggle],
                outputs=[toggle, body],
            )
    # Output components, in exactly the order ``search`` returns its updates.
    allblocks = (
        header +
        post_boxes +
        post_titles +
        toggle_buttons +
        match_buttons +
        post_bodies +
        [post_ids]
    )
    # Submitting the query box and clicking Search both run ``search``.
    query.submit(
        fn=search,
        inputs=[query],
        outputs=allblocks)
    submit.click(
        fn=search,
        inputs=[query],
        outputs=allblocks)
# Start the server only when run as a script, so the module can be imported
# (e.g. by tooling that serves ``demo`` itself) without blocking on launch.
if __name__ == "__main__":
    demo.launch()