# File size: 3,719 Bytes
# Revisions: b2c5d32, e3fc3b8
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import List, Dict
import time
import datetime
import uvicorn

# Load the seq2seq checkpoint and its matching tokenizer once, at import time;
# the functions below read these module-level objects directly.
# NOTE(review): from_pretrained presumably fetches/caches "KN123/nl2sql" on
# first use — confirm the deployment environment permits that.
model = AutoModelForSeq2SeqLM.from_pretrained("KN123/nl2sql")
tokenizer = AutoTokenizer.from_pretrained("KN123/nl2sql")

def get_prompt(tables, question):
    """Return the instruction prompt for the given schema and question.

    Both arguments are interpolated into a fixed template using their default
    string formatting; no escaping or validation is applied.
    """
    template = (
        "convert question and table into SQL query. "
        "tables: {tables}. question: {question}"
    )
    return template.format(tables=tables, question=question)

def prepare_input(question: str, tables: Dict[str, List[str]]):
    """Tokenize a question and its table schema into model input ids.

    Each table is rendered as ``name(col1,col2,...)``; the rendered tables are
    joined with ``", "`` and embedded in the prompt built by ``get_prompt``.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to the list of its column names.

    Returns:
        The ``input_ids`` the tokenizer produces with ``return_tensors="pt"``.
    """
    schema = ", ".join(
        f"{table_name}({','.join(columns)})"
        for table_name, columns in tables.items()
    )
    prompt = get_prompt(schema, question)
    # Ask for truncation explicitly: passing max_length on its own makes the
    # tokenizer fall back to an implicit strategy and warn on every call.
    encoded = tokenizer(
        prompt,
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids

def inference(question: str, tables: Dict[str, List[str]]) -> str:
    """Generate text for a question over the given table schema.

    The prompt is tokenized by ``prepare_input``, moved to the model's device,
    passed to ``model.generate`` and the first returned sequence is decoded
    with special tokens stripped.

    Args:
        question: Natural-language question to translate.
        tables: Mapping of table name to the list of its column names.

    Returns:
        The decoded text of the first generated sequence.
    """
    input_ids = prepare_input(question=question, tables=tables).to(model.device)
    # NOTE(review): top_k is normally a sampling-only setting; with beam search
    # and no do_sample it is presumably ignored — confirm before relying on it.
    generated = model.generate(
        inputs=input_ids,
        num_beams=10,
        top_k=10,
        max_length=512,
    )
    first_sequence = generated[0]
    return tokenizer.decode(token_ids=first_sequence, skip_special_tokens=True)

# Create the ASGI application and attach a permissive CORS policy so the API
# can be called from browser front-ends served from other origins.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# very broad policy (any site may send credentialed requests) — confirm this
# is intended, or list the trusted origins explicitly.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],  # Explicit method list; PATCH and HEAD are not included
    allow_headers=["*"],  # Allows all headers
)

@app.get("/")
def home():
    """Return a static payload describing the service and its version."""
    fields = (
        ("message", "Hello there! Everything is working fine!"),
        ("api-version", "1.0.0"),
        ("role", "nl2sql"),
        (
            "description",
            "This api can be used to convert natural language to SQL given the human prompt, tables and the attributes.",
        ),
    )
    return dict(fields)

@app.get("/test-generate")
def generate(text:str):
    """Run a fixed sample question through the model as a smoke test.

    Args:
        text: Required query parameter. It is accepted but not used; the
            question and schema sent to the model are hard-coded below.

    Returns:
        A dict with the generated text (``api_response``), the elapsed
        inference time in seconds (``time_taken(s)``) and ``request_details``
        holding the current UTC datetime, the local date and the local
        timezone name, all rendered as strings.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    res = inference("how many people with name jui and age less than 25", {
        "people_name":["id","name"], "people_age": ["people_id","age"]
    })
    total_time_taken = time.perf_counter() - start
    current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
    current_date = datetime.date.today()
    # time.daylight only says whether the local zone *defines* a DST variant,
    # not whether DST is active now; tm_isdst reflects the current state.
    dst_active = time.localtime().tm_isdst > 0
    timezone_name = time.tzname[1 if dst_active else 0]
    print(res)
    return {
        "api_response": f"{res}",
        "time_taken(s)": f"{total_time_taken}",
        "request_details": {
            "utc_datetime": f"{current_utc_datetime}",
            "current_date": f"{current_date}",
            "timezone_name": f"{timezone_name}"
        }
    }

@app.post("/generate")
def generate(request_body:Dict):
    """Translate a natural-language question into SQL for the given schema.

    Args:
        request_body: JSON object with ``text`` (the question) and ``tables``
            (an object mapping each table name to a list of column names).

    Returns:
        A dict with the generated text (``api_response``), the elapsed
        inference time in seconds (``time_taken(s)``) and ``request_details``
        holding the current UTC datetime, the local date and the local
        timezone name, all rendered as strings.

    Raises:
        HTTPException: 400 when ``text`` or ``tables`` is missing, or when
            ``tables`` is not an object of string lists.
    """
    if 'text' not in request_body or 'tables' not in request_body:
        raise HTTPException(status_code=400, detail="Missing 'text' or 'tables' in request body")

    prompt = request_body['text']
    tables = request_body['tables']

    # Reject malformed schemas up front: otherwise a non-dict value fails deep
    # inside prompt building (HTTP 500) and a string value is silently joined
    # character by character.
    tables_valid = isinstance(tables, dict) and all(
        isinstance(columns, list)
        and all(isinstance(column, str) for column in columns)
        for columns in tables.values()
    )
    if not tables_valid:
        raise HTTPException(
            status_code=400,
            detail="'tables' must map each table name to a list of column names",
        )

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    start = time.perf_counter()
    res = inference(prompt, tables)
    total_time_taken = time.perf_counter() - start
    current_utc_datetime = datetime.datetime.now(datetime.timezone.utc)
    current_date = datetime.date.today()
    # time.daylight only says whether the local zone *defines* a DST variant,
    # not whether DST is active now; tm_isdst reflects the current state.
    dst_active = time.localtime().tm_isdst > 0
    timezone_name = time.tzname[1 if dst_active else 0]
    print(res)
    return {
        "api_response": f"{res}",
        "time_taken(s)": f"{total_time_taken}",
        "request_details": {
            "utc_datetime": f"{current_utc_datetime}",
            "current_date": f"{current_date}",
            "timezone_name": f"{timezone_name}"
        }
    }



# When executed as a script, serve the app with uvicorn on the loopback
# interface (127.0.0.1:8000); merely importing the module starts no server.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)