Spaces:
Runtime error
Runtime error
import os
import tempfile
import time

import gradio as gr
import pandas as pd
from thirdai import bolt, licensing
# Register the ThirdAI license file; this runs before the model is loaded below.
licensing.set_path("license.serialized")

# Number of result cards shown per search.
max_posts = 5

# Recipe catalogue. search() indexes it by row position (df.iloc).
df = pd.read_csv("processed_recipes_3.csv")

# UDT model queried by search() and fine-tuned in place by retrain().
model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords_4gram.bolt")

# Map each value of the first CSV column (presumably RecipeId -- TODO confirm)
# to its row position. Built in one pass over the column instead of one
# df.iloc lookup per row, which is very slow on a ~522K-row frame. If an id
# repeats, the later row wins (same as the original per-row loop).
# Not read by any live code path in this file.
recipe_id_to_row_num = {
    recipe_id: row for row, recipe_id in enumerate(df.iloc[:, 0])
}
# Markdown banner rendered at the top of the page.
INTRO_MARKDOWN = (
    "# A billion parameter model, trained on a single CPU, in just 90 mins, "
    "on 522K recipes from food.com !!\n"
)

# Label of each result's "like" button, and the label it switches to after
# the feedback has been applied.
# NOTE(review): the leading "π" looks like a mis-encoded emoji (mojibake);
# confirm the intended character and that the file is saved as UTF-8.
LIKE_TEXT = "π update LLM"
FEEDBACK_RECEIVED_TEXT = "π Click search for updated results"

# The two labels the expand/collapse toggle alternates between.
SHOW_MORE, SHOW_LESS = "Show more", "Show less"
def retrain(query, doc_id): | |
query = query.lower() | |
query.replace('\n', ' ') | |
query = ' '.join([query[i:i+4] for i in range(len(query)-3)]) | |
df = pd.DataFrame({ | |
"Name": [query], | |
"RecipeId": [str(doc_id)] | |
}) | |
filename = f"temptrain{hash(query)}{hash(doc_id)}{time.time()}.csv" | |
df.to_csv(filename) | |
prediction = None | |
while prediction != doc_id: | |
model.train(filename, epochs=1) | |
prediction = model.predict( | |
{"Name": query.replace('\n', ' ')}, | |
return_predicted_class=True) | |
os.remove(filename) | |
# sample = {"query": query.replace('\n', ' '), "id": str(doc_id)} | |
# batch = [sample] | |
# prediction = None | |
# while prediction != doc_id: | |
# model.train_batch(batch, metrics=["categorical_accuracy"]) | |
# prediction = model.predict(sample, return_predicted_class=True) | |
def search(query): | |
query = query.lower() | |
query = ' '.join([query[i:i+4] for i in range(len(query)-3)]) | |
scores = model.predict({"Name": query}) | |
#### | |
sorted_ids = scores.argsort()[-max_posts:][::-1] | |
relevant_posts = [ | |
df.iloc[pid] for pid in sorted_ids | |
] | |
#### | |
# K = min(2*max_posts, len(scores) - 1) | |
# sorted_post_ids = scores.argsort()[-K:][::-1] | |
# print(sorted_post_ids) | |
# sorted_ids = [] | |
# relevant_posts = [] | |
# count = 0 | |
# for pid in sorted_post_ids: | |
# if pid in recipe_id_to_row_num: | |
# relevant_posts.append(df.iloc[recipe_id_to_row_num[pid]]) | |
# sorted_ids.append(pid) | |
# count += 1 | |
# if count==max_posts: | |
# break | |
#### | |
header = [gr.Markdown.update(visible=True)] | |
boxes = [ | |
gr.Box.update(visible=True) | |
for _ in relevant_posts | |
] | |
titles = [ | |
gr.Markdown.update(f"## {post['Name']}") | |
for post in relevant_posts | |
] | |
toggles = [ | |
gr.Button.update( | |
visible=True, | |
value=SHOW_MORE, | |
interactive=True, | |
) | |
for _ in relevant_posts | |
] | |
matches = [ | |
gr.Button.update( | |
value=LIKE_TEXT, | |
interactive=True, | |
) | |
for _ in relevant_posts | |
] | |
bodies = [ | |
gr.HTML.update( | |
visible=False, | |
value=f"<br/>" | |
f"<h2>Description:</h2>\n{post['Description']}\n\n" | |
"<hr class='solid'>" | |
f"<h2>Ingredients:</h2>\n{post['RecipeIngredientParts']}\n\n" | |
"<br/>" | |
f"<h2>Instructions:</h2>\n{post['RecipeInstructions']}\n\n" | |
"<br/>") | |
for post in relevant_posts | |
] | |
return ( | |
header + | |
boxes + | |
titles + | |
toggles + | |
matches + | |
bodies + | |
[sorted_ids] | |
) | |
def handle_toggle(toggle): | |
if toggle == SHOW_MORE: | |
new_toggle_text = SHOW_LESS | |
visible = True | |
if toggle == SHOW_LESS: | |
new_toggle_text = SHOW_MORE | |
visible = False | |
return [ | |
gr.Button.update(new_toggle_text), | |
gr.HTML.update(visible=visible), | |
] | |
def handle_feedback(button_id: int): | |
def register_feedback(doc_ids, query): | |
retrain( | |
query=query, | |
doc_id=doc_ids[button_id] | |
) | |
return gr.Button.update( | |
value=FEEDBACK_RECEIVED_TEXT, | |
interactive=False, | |
) | |
return register_feedback | |
default_query = ( | |
"biryani lamb spicy contains cloves and red chili powder, made with ghee and hard boiled eggs, made by grinding coconut and cashew" | |
) | |
with gr.Blocks() as demo: | |
gr.Markdown(INTRO_MARKDOWN) | |
query = gr.Textbox(value=default_query, label="Query", lines=10) | |
submit = gr.Button(value="Search") | |
header = [gr.Markdown("# Relevant Recipes", visible=False)] | |
post_boxes = [] | |
post_titles = [] | |
toggle_buttons = [] | |
match_buttons = [] | |
post_bodies = [] | |
post_ids = gr.State([]) | |
for i in range(max_posts): | |
with gr.Box(visible=False) as box: | |
post_boxes.append(box) | |
with gr.Row(): | |
with gr.Column(scale=5): | |
title = gr.Markdown("") | |
post_titles.append(title) | |
with gr.Column(scale=1, min_width=370): | |
with gr.Row(): | |
with gr.Column(scale=3, min_width=170): | |
toggle = gr.Button(SHOW_MORE) | |
toggle_buttons.append(toggle) | |
with gr.Column(scale=1, min_width=170): | |
match = gr.Button(LIKE_TEXT) | |
match.click( | |
fn=handle_feedback(button_id=i), | |
inputs=[post_ids, query], | |
outputs=[match], | |
) | |
match_buttons.append(match) | |
body = gr.HTML("") | |
post_bodies.append(body) | |
toggle.click( | |
fn=handle_toggle, | |
inputs=[toggle], | |
outputs=[toggle, body], | |
) | |
allblocks = ( | |
header + | |
post_boxes + | |
post_titles + | |
toggle_buttons + | |
match_buttons + | |
post_bodies + | |
[post_ids] | |
) | |
query.submit( | |
fn=search, | |
inputs=[query], | |
outputs=allblocks) | |
submit.click( | |
fn=search, | |
inputs=[query], | |
outputs=allblocks) | |
demo.launch() |