Spaces:
Sleeping
Sleeping
barathm111
commited on
Commit
•
9dfd473
1
Parent(s):
130beb7
Upload app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,8 @@ class QueryRequest(BaseModel):
|
|
13 |
@app.get("/")
|
14 |
def home():
|
15 |
return {"message": "SQL Generation Server is running"}
|
|
|
|
|
16 |
@app.post("/generate")
|
17 |
def generate(request: QueryRequest):
|
18 |
try:
|
@@ -22,29 +24,15 @@ def generate(request: QueryRequest):
|
|
22 |
|
23 |
generated_text = output[0]['generated_text']
|
24 |
sql_query = generated_text.split("SQL query:")[-1].strip()
|
25 |
-
|
26 |
-
|
|
|
27 |
raise ValueError("Generated text is not a valid SQL query")
|
28 |
|
29 |
-
# Further validation to ensure no additional text
|
30 |
-
sql_query = sql_query.split(';')[0].strip()
|
31 |
-
|
32 |
-
# Comprehensive list of SQL keywords
|
33 |
-
allowed_keywords = {
|
34 |
-
'select', 'insert', 'update', 'delete', 'show', 'describe', 'from', 'where', 'and', 'or', 'like', 'limit', 'order by', 'group by', 'join', 'inner join', 'left join', 'right join', 'full join', 'on', 'using', 'union', 'union all', 'distinct', 'having', 'into', 'values', 'set', 'create', 'alter', 'drop', 'table', 'database', 'index', 'view', 'trigger', 'procedure', 'function', 'if', 'exists', 'primary key', 'foreign key', 'references', 'check', 'constraint', 'default', 'auto_increment', 'null', 'not null', 'in', 'is', 'is not', 'between', 'case', 'when', 'then', 'else', 'end', 'asc', 'desc', 'count', 'sum', 'avg', 'min', 'max', 'timestamp', 'date', 'time', 'varchar', 'char', 'int', 'integer', 'smallint', 'bigint', 'decimal', 'numeric', 'float', 'real', 'double', 'boolean', 'enum', 'text', 'blob', 'clob'
|
35 |
-
}
|
36 |
-
# Ensure the query only contains allowed keywords
|
37 |
-
tokens = sql_query.lower().split()
|
38 |
-
for token in tokens:
|
39 |
-
if not any(token.startswith(keyword) for keyword in allowed_keywords):
|
40 |
-
raise ValueError("Generated text contains invalid SQL syntax")
|
41 |
-
|
42 |
return {"output": sql_query}
|
43 |
except Exception as e:
|
44 |
raise HTTPException(status_code=500, detail=str(e))
|
45 |
|
46 |
-
|
47 |
-
|
48 |
if __name__ == "__main__":
|
49 |
import uvicorn
|
50 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
13 |
@app.get("/")
|
14 |
def home():
|
15 |
return {"message": "SQL Generation Server is running"}
|
16 |
+
|
17 |
+
|
18 |
@app.post("/generate")
|
19 |
def generate(request: QueryRequest):
|
20 |
try:
|
|
|
24 |
|
25 |
generated_text = output[0]['generated_text']
|
26 |
sql_query = generated_text.split("SQL query:")[-1].strip()
|
27 |
+
|
28 |
+
# Basic validation
|
29 |
+
if not sql_query.lower().startswith(('select', 'show', 'describe')):
|
30 |
raise ValueError("Generated text is not a valid SQL query")
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
return {"output": sql_query}
|
33 |
except Exception as e:
|
34 |
raise HTTPException(status_code=500, detail=str(e))
|
35 |
|
|
|
|
|
36 |
if __name__ == "__main__":
|
37 |
import uvicorn
|
38 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|