import gradio as gr
import os
import pymongo
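# `spaces` provides the @spaces.GPU decorator, which requests GPU time for a
# call when the app runs on Hugging Face ZeroGPU Spaces.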
import spaces


from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
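
# NOTE: embedding_model, collection, tokenizer, and model are module-level
# globals initialized in the __main__ block at the bottom of this file.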


def get_embedding(text: str) -> list[float] | None:
    """Embed `text` with the sentence-transformer model; return None for empty input."""
    if not text.strip():
        print("Attempted to get embedding for empty text.")
        return None

    embedding = embedding_model.encode(text)

    return embedding.tolist()


def get_mongo_client(mongo_uri):
    """Establish connection to the MongoDB."""
    try:
        client = pymongo.MongoClient(mongo_uri)
        print("Connection to MongoDB successful")
        return client
    except pymongo.errors.ConnectionFailure as e:
        print(f"Connection failed: {e}")
        return None


def vector_search(user_query, collection):
    """Run a vector search over the collection for the user query."""

    # Generate embedding for the user query
    query_embedding = get_embedding(user_query)

    if query_embedding is None:
        print("Invalid query or embedding generation failed.")
        return []

    # Define the vector search pipeline
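    # NOTE: "$vectorSearch" assumes a MongoDB Atlas Vector Search index named
    # "vector_index" on the "embedding" field, with dimensions matching the
    # embedding model's output (1024 for thenlper/gte-large).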
    pipeline = [
        {
            "$vectorSearch": {
                "index": "vector_index",
                "queryVector": query_embedding,
                "path": "embedding",
                "numCandidates": 150,  # Number of candidate matches to consider
                "limit": 4,  # Return top 4 matches
            }
        },
        {
            "$project": {
                "_id": 0,
                "title": 1,
                "ingredients": 1,
                "directions": 1,
                "score": {"$meta": "vectorSearchScore"},  # Include the search score
            }
        },
    ]

    # Execute the search
    results = collection.aggregate(pipeline)
    return list(results)


def get_search_result(query, collection):
    """Format the vector search matches into a text block for the LLM prompt."""

    get_knowledge = vector_search(query, collection)

    search_result = ""
    for result in get_knowledge:
        search_result += f"Recipe Name: {result.get('title', 'N/A')}, Ingredients: {result.get('ingredients', 'N/A')}, Directions: {result.get('directions', 'N/A')}\n"

    return search_result, get_knowledge


@spaces.GPU
def process_response(message, history):
    """Answer a chat message: retrieve matching recipes, then generate with Gemma."""
    source_information, matches = get_search_result(message, collection)
    recipe_dict = {}
    for match in matches:
        name = match.pop("title")
        recipe_dict[name] = match

    combined_information = f"Query: {message}\nContinue to answer the query by using the Search Results:\n{source_information}."
    inputs = tokenizer(combined_information, return_tensors="pt").to("cuda")
    response = model.generate(**inputs, max_new_tokens=500)
    response_text = tokenizer.decode(response[0]).split("\n.\n")[-1].split("<eos>")[0].strip()

    if not recipe_dict:
        return response_text

    # Prefer the recipe the model actually mentioned; fall back to the top match
    matched_recipe = ""
    for title in recipe_dict:
        if title in response_text:
            matched_recipe = title
            break
    if not matched_recipe:
        matched_recipe = next(iter(recipe_dict))
    recipe = recipe_dict[matched_recipe]

    # Append the full matched recipe as Markdown lists
    response_text += f"\n\nRecipe for **{matched_recipe}**:"
    response_text += "\n### List of ingredients:\n- {0}".format("\n- ".join(recipe["ingredients"].split(", ")))
    response_text += "\n### Directions:\n- {0}".format(".\n- ".join(recipe["directions"].split(". ")))

    return response_text


if __name__ == "__main__":
    
    embedding_model = SentenceTransformer("thenlper/gte-large")

    mongo_uri = os.getenv("MONGO_URI")
    if not mongo_uri:
        raise ValueError("MONGO_URI not set in environment variables")

    mongo_client = get_mongo_client(mongo_uri)
    if mongo_client is None:
        raise ConnectionError("Could not connect to MongoDB")

    # Connect to the recipe database and collection
    db = mongo_client["recipe"]
    collection = db["recipe_collection"]
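
    # Each document in recipe.recipe_collection is expected to look like:
    #   {
    #       "title": "Recipe name",
    #       "ingredients": "flour, sugar, eggs",      # comma-separated string
    #       "directions": "Mix well. Bake at 350F.",  # sentence-separated string
    #       "embedding": [...],                       # gte-large vector of the recipe text
    #   }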

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", device_map="auto")
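    # device_map="auto" lets accelerate place the model on the available GPU,
    # matching the .to("cuda") on the tokenized inputs in process_response.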
    
    gr.ChatInterface(process_response).queue().launch()