# recipe-rag / app.py
# adbcode's picture
# Update app.py
# cd603de verified
import gradio as gr
import os
import pymongo
import spaces
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
def get_embedding(text: str) -> list[float]:
    """Encode ``text`` with the module-level ``embedding_model``.

    Returns the result of ``embedding_model.encode(text).tolist()``. When
    ``text`` is empty or whitespace-only, a notice is printed and an empty
    list is returned instead, without calling the model.
    """
    is_blank = text.strip() == ""
    if is_blank:
        print("Attempted to get embedding for empty text.")
        return []

    vector = embedding_model.encode(text)
    return vector.tolist()
def get_mongo_client(mongo_uri):
    """Establish a connection to MongoDB and verify the server is reachable.

    Args:
        mongo_uri: MongoDB connection string passed to ``pymongo.MongoClient``.

    Returns:
        The ``pymongo.MongoClient`` once a ``ping`` command has succeeded, or
        ``None`` if a ``ConnectionFailure`` was raised while connecting.
    """
    client = None
    try:
        client = pymongo.MongoClient(mongo_uri)
        # MongoClient connects lazily: constructing it does not prove that a
        # server is reachable. A cheap "ping" forces a round trip so that an
        # unreachable server surfaces here as ConnectionFailure instead of on
        # the first real query.
        client.admin.command("ping")
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        if client is not None:
            # Release the half-initialised client instead of leaking it.
            client.close()
        return None

    print("Connection to MongoDB successful")
    return client
def vector_search(user_query, collection):
    """Run a ``$vectorSearch`` aggregation for ``user_query`` on ``collection``.

    Args:
        user_query: Free-text query; it is embedded with ``get_embedding``.
        collection: Collection object exposing ``aggregate(pipeline)``.

    Returns:
        A list with up to 4 result documents, each projected to ``title``,
        ``ingredients``, ``directions`` and ``score`` (the vector search
        score). An empty list is returned when no embedding could be produced
        for the query (e.g. blank input), so callers can always iterate the
        result.
    """
    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    # get_embedding signals failure with an empty list (never None), so test
    # for emptiness; searching with an empty query vector would be invalid.
    if not query_embedding:
        return []

    # Define the vector search pipeline
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": 150,  # Number of candidate matches to consider
                "limit": 4,  # Return top 4 matches
            }
        },
        {
            "$project": {
                "_id": 0,
                "title": 1,
                "ingredients": 1,
                "directions": 1,
                "score": {"$meta": "vectorSearchScore"},  # Include the search score
            }
        },
    ]

    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)
def get_search_result(query, collection):
    """Search for ``query`` and summarise the hits as prompt-ready text.

    Returns a ``(text, matches)`` tuple: ``matches`` is whatever
    ``vector_search`` returned, and ``text`` has one newline-terminated line
    per match listing its title, ingredients and directions ("N/A" for any
    field that is absent).
    """
    matches = vector_search(query, collection)
    summary_lines = [
        f"Recipe Name: {match.get('title', 'N/A')}, Ingredients: {match.get('ingredients', 'N/A')}, Directions: {match.get('directions', 'N/A')}\n"
        for match in matches
    ]
    return "".join(summary_lines), matches
def _format_recipe(title, recipe):
    """Render one recipe (title, ingredients, directions) as a Markdown appendix."""
    # Fall back to "N/A" for absent fields instead of raising KeyError.
    # NOTE(review): assumes both fields are plain strings - confirm against the
    # collection schema.
    ingredients = recipe.get("ingredients", "N/A")
    directions = recipe.get("directions", "N/A")

    text = f"\n\nRecipe for **{title}**:"
    text += "\n### List of ingredients:\n- {0}".format("\n- ".join(ingredients.split(", ")))
    text += "\n### Directions:\n- {0}".format(".\n- ".join(directions.split(". ")))
    return text


@spaces.GPU
def process_response(message, history):
    """Answer a chat message using retrieved recipes as context.

    Args:
        message: The user's latest chat message.
        history: Prior conversation; accepted for the chat-callback signature
            but not used.

    Returns:
        The model's generated answer. When the search returned at least one
        titled recipe, a Markdown rendering of one recipe is appended: the
        first whose title appears in the answer, otherwise the first search
        hit.
    """
    source_information, matches = get_search_result(message, collection)

    # Index the hits by title without mutating the search results; hits that
    # lack a title cannot be referenced and are skipped.
    recipe_dict = {
        match["title"]: {key: value for key, value in match.items() if key != "title"}
        for match in matches
        if "title" in match
    }

    combined_information = f"Query: {message}\nContinue to answer the query by using the Search Results:\n{source_information}."
    inputs = tokenizer(combined_information, return_tensors="pt").to("cuda")
    response = model.generate(**inputs, max_new_tokens=500)

    # The generated sequence starts with the prompt tokens; decode only what
    # follows them. This avoids locating the answer by splitting the decoded
    # text on a marker, which leaks the whole prompt into the reply whenever
    # the marker is not present.
    prompt_length = inputs["input_ids"].shape[-1]
    response_text = tokenizer.decode(
        response[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Nothing retrieved: return the bare answer instead of failing while
    # picking a recipe from an empty mapping.
    if not recipe_dict:
        return response_text

    # Prefer a recipe the model actually mentioned; otherwise the top hit.
    matched_recipe = next(
        (title for title in recipe_dict if title in response_text),
        next(iter(recipe_dict)),
    )
    return response_text + _format_recipe(matched_recipe, recipe_dict[matched_recipe])
if __name__ == "__main__":
    # Validate configuration first so a missing URI fails fast, before any
    # model is loaded.
    mongo_uri = os.getenv("MONGO_URI")
    if not mongo_uri:
        raise ValueError("MONGO_URI not set in environment variables")

    mongo_client = get_mongo_client(mongo_uri)
    if mongo_client is None:
        # get_mongo_client reports connection failures by returning None;
        # stop here with a clear error rather than a TypeError on subscripting.
        raise RuntimeError("Could not connect to MongoDB")

    # Select the database and collection queried by the vector search.
    db = mongo_client["recipe"]
    collection = db["recipe_collection"]

    # Module-level models used by get_embedding and process_response.
    embedding_model = SentenceTransformer("thenlper/gte-large")
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")

    gr.ChatInterface(process_response).queue().launch()