omkar334 commited on
Commit
7d7a6a9
1 Parent(s): 44493b6

basic agent

Browse files
Files changed (1) hide show
  1. agent.py +87 -0
agent.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+ from strictjson import strict_json_async
5
+
6
+ from sarvam import speaker, translator
7
+
8
+ load_dotenv()
9
+
10
+ RAG_SYS_PROMPT = None
11
+ RAG_USER_PROMPT = None
12
+
13
+ AGENT_PROMPT = """You are an AI agent.
14
+ You are given three functions - retriever (Retreives information from a database), translator and a speaker (converts text to speech).
15
+ The database is a Grade {} {} Textbook. Your task is to assess the user query and determine which function to call.
16
+ If the function is to be called, return response as None. If any function is not needed, you can answer to the query yourself. Also identify keywords in the query,
17
+ """
18
+
19
+
20
+ async def llm(system_prompt: str, user_prompt: str) -> str:
21
+ from groq import AsyncGroq
22
+
23
+ client = AsyncGroq(api_key=os.get_env("GROQ_API_KEY"))
24
+
25
+ messages = [
26
+ {"role": "system", "content": system_prompt},
27
+ {"role": "user", "content": user_prompt},
28
+ ]
29
+
30
+ chat_completion = await client.chat.completions.create(
31
+ messages=messages,
32
+ model="llama3-70b-8192",
33
+ temperature=0.3,
34
+ max_tokens=360,
35
+ top_p=1,
36
+ stop=None,
37
+ stream=False,
38
+ )
39
+
40
+ return chat_completion.choices[0].message.content
41
+
42
+
43
+ async def call_agent(user_prompt, grade, subject):
44
+ system_prompt = AGENT_PROMPT.format(grade, subject)
45
+
46
+ result = await strict_json_async(
47
+ system_prompt=system_prompt,
48
+ user_prompt=user_prompt,
49
+ output_format={
50
+ "function": 'Type of function to call, type: Enum["retriever", "translator", "speaker", "none"]',
51
+ "keywords": "Array of keywords, type: List[str]",
52
+ "src_lang": "Identify the language that the user query is in, type: str",
53
+ "dest_lang": """Identify the target language from the user query if the function is either "translator" or "speaker". If language is not found, return "none",
54
+ type: Enum["hindi", "bengali", "kannada", "malayalam", "marathi", "odia", "punjabi", "tamil", "telugu", "english", "gujarati", "none"]""",
55
+ "source": "Identify the sentence that the user wants to translate or speak. Retu 'none', type: Optional[str]",
56
+ "response": "Your response, type: Optional[str]",
57
+ },
58
+ llm=llm,
59
+ )
60
+ return result
61
+
62
+
63
+ async def function_caller(user_prompt, grade, subject, client):
64
+ result = call_agent(user_prompt, grade, subject)
65
+ function = result["function"].lower()
66
+
67
+ if function == "none":
68
+ return result["response"]
69
+
70
+ elif function == "retriever":
71
+ collection = f"{grade}_{subject}"
72
+
73
+ data = client.search(collection, user_prompt)
74
+ data = [i.document for i in data]
75
+
76
+ system_prompt = RAG_SYS_PROMPT.format(grade, subject)
77
+ user_prompt = RAG_USER_PROMPT.format(user_prompt)
78
+
79
+ response = await llm(system_prompt, user_prompt)
80
+
81
+ return response
82
+
83
+ elif function == "translator":
84
+ return await translator(result["response"], result["src_lang"], result["dest_lang"])
85
+
86
+ elif function == "speaker":
87
+ return await speaker(result["response"], result["dest_lang"])