from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import List, Dict
import time
import datetime
import uvicorn
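
# Load the fine-tuned NL-to-SQL seq2seq model and its tokenizer from the Hugging Face Hub.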
model = AutoModelForSeq2SeqLM.from_pretrained("KN123/nl2sql")
tokenizer = AutoTokenizer.from_pretrained("KN123/nl2sql")

def get_prompt(tables, question):
    # Compose the text-to-SQL prompt passed to the model.
    prompt = f"""convert question and table into SQL query. tables: {tables}. question: {question}"""
    return prompt

def prepare_input(question: str, tables: Dict[str, List[str]]):
    # Render each table as "name(col1,col2,...)" and join them into a single schema string.
    table_specs = [f"""{table_name}({",".join(tables[table_name])})""" for table_name in tables]
    table_specs = ", ".join(table_specs)
    prompt = get_prompt(table_specs, question)
    # Tokenize the prompt, truncating to the model's 512-token limit.
    input_ids = tokenizer(prompt, max_length=512, truncation=True, return_tensors="pt").input_ids
    return input_ids

def inference(question: str, tables: Dict[str, List[str]]) -> str:
    input_data = prepare_input(question=question, tables=tables)
    input_data = input_data.to(model.device)
    # Beam-search decoding; top_k has no effect unless sampling is enabled.
    outputs = model.generate(inputs=input_data, num_beams=10, top_k=10, max_length=512)
    result = tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)
    return result
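
# FastAPI application; CORS is left wide open so browser-based front ends can call the API.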
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow requests from any origin
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],  # allowed HTTP methods
    allow_headers=["*"],  # allow all request headers
)
@app.get("/")
def home():
return {
"message" : "Hello there! Everything is working fine!",
"api-version": "1.0.0",
"role": "nl2sql",
"description": "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes."
}
@app.get("/test-generate")
def generate(text:str):
start = time.time()
res = inference("how many people with name jui and age less than 25", {
"people_name":["id","name"], "people_age": ["people_id","age"]
})
end = time.time()
total_time_taken = end - start
current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
current_date = datetime.date.today()
timezone_name = time.tzname[time.daylight]
print(res)
return {
"api_response": f"{res}",
"time_taken(s)": f"{total_time_taken}",
"request_details": {
"utc_datetime": f"{current_utc_datetime}",
"current_date": f"{current_date}",
"timezone_name": f"{timezone_name}"
}
}
@app.post("/generate")
def generate(request_body:Dict):
if 'text' not in request_body or 'tables' not in request_body:
raise HTTPException(status_code=400, detail="Missing 'text' or 'tables' in request body")
prompt = request_body['text']
tables = request_body['tables']
start = time.time()
res = inference(prompt, tables)
end = time.time()
total_time_taken = end - start
current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
current_date = datetime.date.today()
timezone_name = time.tzname[time.daylight]
print(res)
return {
"api_response": f"{res}",
"time_taken(s)": f"{total_time_taken}",
"request_details": {
"utc_datetime": f"{current_utc_datetime}",
"current_date": f"{current_date}",
"timezone_name": f"{timezone_name}"
}
}

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
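
# Example request (illustrative; assumes the server is running locally on port 8000):
#
#   curl -X POST http://127.0.0.1:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"text": "how many people with name jui and age less than 25",
#          "tables": {"people_name": ["id", "name"], "people_age": ["people_id", "age"]}}'
#
# The generated SQL is returned under "api_response", along with timing metadata.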