Spaces:
Runtime error
Runtime error
endpoints and models
Browse files
agent.py
CHANGED
@@ -32,7 +32,9 @@ async def llm(system_prompt: str, user_prompt: str) -> str:
|
|
32 |
return chat_completion.choices[0].message.content
|
33 |
|
34 |
|
35 |
-
async def call_agent(user_prompt,
|
|
|
|
|
36 |
system_prompt = AGENT_PROMPT.format(grade, subject)
|
37 |
|
38 |
result = await strict_json_async(
|
@@ -52,23 +54,27 @@ async def call_agent(user_prompt, grade, subject):
|
|
52 |
return result
|
53 |
|
54 |
|
55 |
-
async def
|
56 |
grade, subject, chapter = collection.split("_")
|
57 |
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
function = result["function"].lower()
|
60 |
|
61 |
if function == "none":
|
62 |
-
return result["response"]
|
63 |
|
64 |
elif function == "retriever":
|
65 |
-
|
66 |
-
data = [i.document for i in data]
|
67 |
-
|
68 |
-
system_prompt = RAG_SYS_PROMPT.format(subject, grade)
|
69 |
-
user_prompt = RAG_USER_PROMPT.format(data, user_prompt)
|
70 |
-
|
71 |
-
response = await llm(system_prompt, user_prompt)
|
72 |
return {"text": response}
|
73 |
|
74 |
elif function == "translator":
|
|
|
32 |
return chat_completion.choices[0].message.content
|
33 |
|
34 |
|
35 |
+
async def call_agent(user_prompt, collection):
|
36 |
+
grade, subject, chapter = collection.split("_")
|
37 |
+
|
38 |
system_prompt = AGENT_PROMPT.format(grade, subject)
|
39 |
|
40 |
result = await strict_json_async(
|
|
|
54 |
return result
|
55 |
|
56 |
|
57 |
+
async def retriever(user_prompt, collection, client):
    """Answer *user_prompt* with RAG over the textbook chunks in *collection*.

    The collection name encodes ``grade_subject_chapter``; grade and subject
    are pulled out to fill the RAG system prompt. The chapter component is
    unused here (the whole collection is searched), so it is discarded.
    """
    grade, subject, _chapter = collection.split("_")

    # Retrieve matching chunks, then keep only their raw document text.
    hits = client.search(collection, user_prompt)
    documents = [hit.document for hit in hits]

    # Build the two prompt halves and delegate the final answer to the LLM.
    sys_msg = RAG_SYS_PROMPT.format(subject, grade)
    usr_msg = RAG_USER_PROMPT.format(documents, user_prompt)
    return await llm(sys_msg, usr_msg)
|
67 |
+
|
68 |
+
|
69 |
+
async def function_caller(user_prompt, collection, client):
|
70 |
+
result = await call_agent(user_prompt, collection)
|
71 |
function = result["function"].lower()
|
72 |
|
73 |
if function == "none":
|
74 |
+
return {"text": result["response"]}
|
75 |
|
76 |
elif function == "retriever":
|
77 |
+
response = await retriever(user_prompt, collection, client)
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
return {"text": response}
|
79 |
|
80 |
elif function == "translator":
|
app.py
CHANGED
@@ -6,8 +6,9 @@ from fastapi import FastAPI
|
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
-
from agent import function_caller
|
10 |
from client import HybridClient
|
|
|
11 |
|
12 |
app = FastAPI()
|
13 |
hclient = HybridClient()
|
@@ -23,17 +24,46 @@ app.add_middleware(
|
|
23 |
|
24 |
class ChatQuery(BaseModel):
|
25 |
query: str
|
26 |
-
|
|
|
|
|
27 |
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
|
|
32 |
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
collection = f"{grade}_{subject.lower()}_{chapter}"
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
if "text" in response:
|
39 |
output = response["text"]
|
|
|
6 |
from fastapi.middleware.cors import CORSMiddleware
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
+
from agent import function_caller, retriever
|
10 |
from client import HybridClient
|
11 |
+
from sarvam import speaker, translator
|
12 |
|
13 |
app = FastAPI()
|
14 |
hclient = HybridClient()
|
|
|
24 |
|
25 |
class ChatQuery(BaseModel):
    # Request body for the /agent and /rag endpoints.
    # query: the user's natural-language question.
    query: str
    # grade/subject/chapter identify the textbook slice; they are combined
    # into the collection name f"{grade}_{subject.lower()}_{chapter}".
    grade: str
    subject: str
    chapter: str
|
30 |
|
31 |
|
32 |
+
class TranslateQuery(BaseModel):
    # Request body for the /translate endpoint.
    # text: the text to translate.
    text: str
    # src/dest: source and destination language identifiers
    # (passed straight through to the sarvam translator).
    src: str
    dest: str
|
36 |
|
37 |
|
38 |
+
class TTSQuery(BaseModel):
    # Request body for the /tts endpoint.
    # text: the text to synthesize into speech.
    text: str
    # src: the language of the text (passed to the sarvam speaker).
    src: str
|
41 |
+
|
42 |
+
|
43 |
+
@app.get("/agent")
async def agent(query: ChatQuery):
    """Route a chat query through the agent (tool-choosing) pipeline.

    Returns the dict produced by ``function_caller`` (e.g. ``{"text": ...}``).
    """
    # BUG FIX: the original read bare `grade`, `subject`, `chapter`, which do
    # not exist in this scope and raised NameError on every request; they are
    # fields of the incoming ChatQuery model.
    collection = f"{query.grade}_{query.subject.lower()}_{query.chapter}"
    return await function_caller(query.query, collection, hclient)
|
47 |
+
|
48 |
+
|
49 |
+
@app.get("/rag")
async def rag(query: ChatQuery):
    """Answer a query directly via retrieval-augmented generation.

    Returns whatever ``retriever`` produces (the LLM's answer).
    """
    # BUG FIX: the original read bare `grade`, `subject`, `chapter`, which do
    # not exist in this scope and raised NameError on every request; they are
    # fields of the incoming ChatQuery model.
    collection = f"{query.grade}_{query.subject.lower()}_{query.chapter}"
    return await retriever(query.query, collection, hclient)
|
53 |
+
|
54 |
+
|
55 |
+
@app.get("/translate")
async def translate(query: TranslateQuery):
    # Thin endpoint: delegate translation to the sarvam translator helper.
    # Returns the translator's dict (e.g. {"text": ...} — see sarvam.py).
    return await translator(query.text, query.src, query.dest)
|
58 |
+
|
59 |
+
|
60 |
+
@app.get("/tts")
async def tts(query: TTSQuery):
    # Thin endpoint: delegate text-to-speech to the sarvam speaker helper.
    return await speaker(query.text, query.src)
|
63 |
+
|
64 |
+
|
65 |
+
async def gradio_interface(input_text, grade, subject, chapter, history):
|
66 |
+
response = await agent(ChatQuery(query=input_text, grade=grade, subject=subject, chapter=chapter))
|
67 |
|
68 |
if "text" in response:
|
69 |
output = response["text"]
|
sarvam.py
CHANGED
@@ -40,13 +40,13 @@ async def translator(text, src, dest):
|
|
40 |
return {"text": output["translated_text"]}
|
41 |
|
42 |
|
43 |
-
async def speaker(text,
|
44 |
async with aiohttp.ClientSession() as session:
|
45 |
url = "https://api.sarvam.ai/text-to-speech"
|
46 |
|
47 |
payload = {
|
48 |
"inputs": [text],
|
49 |
-
"target_language_code": code_map[
|
50 |
"speaker": "meera",
|
51 |
"pitch": 0,
|
52 |
"pace": 1.25,
|
|
|
40 |
return {"text": output["translated_text"]}
|
41 |
|
42 |
|
43 |
+
async def speaker(text, src="hindi"):
|
44 |
async with aiohttp.ClientSession() as session:
|
45 |
url = "https://api.sarvam.ai/text-to-speech"
|
46 |
|
47 |
payload = {
|
48 |
"inputs": [text],
|
49 |
+
"target_language_code": code_map[src],
|
50 |
"speaker": "meera",
|
51 |
"pitch": 0,
|
52 |
"pace": 1.25,
|