"""Gradio demo: recipe search backed by a ThirdAI UniversalDeepTransformer.

The free-text query is encoded as overlapping character 4-grams and scored by
the model; the top ``max_posts`` recipes are shown. Each result has a "like"
button that fine-tunes the model until it predicts that recipe for the query.
"""

import html
import os
import tempfile
import time

import gradio as gr
import pandas as pd
from thirdai import bolt, licensing

licensing.set_path("license.serialized")

max_posts = 5

df = pd.read_csv("processed_recipes_3.csv")
model = bolt.UniversalDeepTransformer.load("1bn_name_ctg_keywords_4gram.bolt")

# Maps each value of the CSV's first column to its positional row number
# (later duplicates win, as with the original row-by-row loop).
# NOTE(review): ``search`` does not use this; it indexes ``df`` directly by the
# predicted class id. Confirm class ids really are row positions.
recipe_id_to_row_num = dict(zip(df.iloc[:, 0], range(df.shape[0])))

INTRO_MARKDOWN = (
    "# A billion parameter model, trained on a single CPU, in just 90 mins, "
    "on 522K recipes from food.com !!"
)
LIKE_TEXT = "👍 update LLM"
FEEDBACK_RECEIVED_TEXT = "👌 Click search for updated results"
SHOW_MORE = "Show more"
SHOW_LESS = "Show less"

# Upper bound on fine-tuning epochs per "like", so a single example that the
# model never fits cannot hang the request handler forever.
MAX_RETRAIN_EPOCHS = 50


def _featurize(query):
    """Encode a query identically for training and inference.

    Lower-cases, flattens newlines to spaces, then joins overlapping character
    4-grams with spaces (e.g. "soups" -> "soup oups"). Text shorter than four
    characters is returned whole rather than collapsing to "".
    """
    text = query.lower().replace("\n", " ")
    if len(text) < 4:
        return text
    return " ".join(text[i : i + 4] for i in range(len(text) - 3))


def retrain(query, doc_id, max_epochs=MAX_RETRAIN_EPOCHS):
    """Fine-tune ``model`` until it predicts ``doc_id`` for ``query``.

    A one-row CSV (``Name`` = featurized query, ``RecipeId`` = ``str(doc_id)``)
    is written to a temporary file, and the model is trained one epoch at a
    time until its predicted class matches ``doc_id`` or ``max_epochs`` epochs
    have run. The temporary file is always removed.

    Returns:
        True if the model ended up predicting ``doc_id``, otherwise False.
    """
    features = _featurize(query)
    train_df = pd.DataFrame({"Name": [features], "RecipeId": [str(doc_id)]})

    fd, filename = tempfile.mkstemp(prefix="temptrain", suffix=".csv")
    os.close(fd)
    try:
        train_df.to_csv(filename)
        for _ in range(max_epochs):
            model.train(filename, epochs=1)
            prediction = model.predict(
                {"Name": features}, return_predicted_class=True
            )
            # The label was written as str(doc_id), and doc_id may be a numpy
            # integer from argsort, so compare string forms.
            if str(prediction) == str(doc_id):
                return True
        return False
    finally:
        os.remove(filename)


def _render_body(post):
    """Build the (initially hidden) HTML detail panel for one recipe row."""
    sections = (
        ("Description", post["Description"]),
        ("Ingredients", post["RecipeIngredientParts"]),
        ("Instructions", post["RecipeInstructions"]),
    )
    # Recipe text comes from a scraped CSV; escape it before injecting as HTML.
    return "".join(
        f"<p><b>{label}:</b><br>{html.escape(str(value))}</p>"
        for label, value in sections
    )


def search(query):
    """Return Gradio updates showing the top ``max_posts`` recipes for ``query``.

    The returned list is ordered to match ``allblocks``: header, boxes, titles,
    toggle buttons, like buttons, bodies, and finally the displayed ids (stored
    in the ``post_ids`` state so each like button knows its recipe).
    """
    scores = model.predict({"Name": _featurize(query)})
    # Class ids of the best scores, highest first.
    sorted_ids = scores.argsort()[-max_posts:][::-1]
    # NOTE(review): class ids are used as positional row numbers in ``df``.
    relevant_posts = [df.iloc[pid] for pid in sorted_ids]

    header = [gr.Markdown.update(visible=True)]
    boxes = [gr.Box.update(visible=True) for _ in relevant_posts]
    titles = [gr.Markdown.update(f"## {post['Name']}") for post in relevant_posts]
    toggles = [
        gr.Button.update(visible=True, value=SHOW_MORE, interactive=True)
        for _ in relevant_posts
    ]
    matches = [
        gr.Button.update(value=LIKE_TEXT, interactive=True) for _ in relevant_posts
    ]
    bodies = [
        gr.HTML.update(visible=False, value=_render_body(post))
        for post in relevant_posts
    ]
    return header + boxes + titles + toggles + matches + bodies + [sorted_ids]


def handle_toggle(toggle):
    """Flip one result between collapsed and expanded.

    ``toggle`` is the button's current label. Returns updates for that button
    (new label) and its HTML body (new visibility).
    """
    expand = toggle == SHOW_MORE
    return [
        gr.Button.update(SHOW_LESS if expand else SHOW_MORE),
        gr.HTML.update(visible=expand),
    ]


def handle_feedback(button_id: int):
    """Build the click handler for the like button in result slot ``button_id``."""

    def register_feedback(doc_ids, query):
        # doc_ids is the post_ids state; until a search fills this slot there
        # is no recipe to retrain on, so leave the button unchanged.
        if button_id >= len(doc_ids):
            return gr.Button.update()
        retrain(query=query, doc_id=doc_ids[button_id])
        return gr.Button.update(value=FEEDBACK_RECEIVED_TEXT, interactive=False)

    return register_feedback


default_query = (
    "biryani lamb spicy contains cloves and red chili powder, made with ghee "
    "and hard boiled eggs, made by grinding coconut and cashew"
)

with gr.Blocks() as demo:
    gr.Markdown(INTRO_MARKDOWN)
    query = gr.Textbox(value=default_query, label="Query", lines=10)
    submit = gr.Button(value="Search")

    header = [gr.Markdown("# Relevant Recipes", visible=False)]
    post_boxes = []
    post_titles = []
    toggle_buttons = []
    match_buttons = []
    post_bodies = []
    # Ids of the currently displayed results, in slot order (set by search).
    post_ids = gr.State([])

    for i in range(max_posts):
        with gr.Box(visible=False) as box:
            post_boxes.append(box)
            with gr.Row():
                with gr.Column(scale=5):
                    title = gr.Markdown("")
                    post_titles.append(title)
                with gr.Column(scale=1, min_width=370):
                    with gr.Row():
                        with gr.Column(scale=3, min_width=170):
                            toggle = gr.Button(SHOW_MORE)
                            toggle_buttons.append(toggle)
                        with gr.Column(scale=1, min_width=170):
                            match = gr.Button(LIKE_TEXT)
                            match.click(
                                fn=handle_feedback(button_id=i),
                                inputs=[post_ids, query],
                                outputs=[match],
                            )
                            match_buttons.append(match)
            body = gr.HTML("")
            post_bodies.append(body)
            toggle.click(
                fn=handle_toggle,
                inputs=[toggle],
                outputs=[toggle, body],
            )

    # Output order must match the list returned by search().
    allblocks = (
        header
        + post_boxes
        + post_titles
        + toggle_buttons
        + match_buttons
        + post_bodies
        + [post_ids]
    )

    query.submit(fn=search, inputs=[query], outputs=allblocks)
    submit.click(fn=search, inputs=[query], outputs=allblocks)

if __name__ == "__main__":
    demo.launch()