# Source: uploaded by barathm111 ("Upload 5 files", commit 6d901c3, verified); original size 2.47 kB.
import os
from flask import Flask, request, jsonify, render_template
from transformers import pipeline
import mysql.connector
from groq import Groq
# Flask application serving the chatbot page and JSON API.
# (__name__ tells Flask where to find templates/static files; the previous
# bare `name` was undefined and raised NameError at import time.)
app = Flask(__name__)

# Text-to-SQL model, loaded once at module import.
pipe = pipeline("text-generation", model="defog/sqlcoder-7b-2")

# Groq client used to turn SQL results back into natural language.
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# Database connection details. Host, user and database name can be overridden
# through environment variables; the defaults are the previously hard-coded
# values. The password is only ever read from the environment.
DB_CONFIG = {
    'host': os.environ.get("DB_HOST", 'auth-db579.hstgr.io'),
    'user': os.environ.get("DB_USER", 'u121769371_ki_aiml_test'),
    'password': os.environ.get("DB_PASSWORD"),
    'database': os.environ.get("DB_NAME", 'u121769371_ki_aiml_test'),
}
def generate_sql(text, max_new_tokens=50):
    """Translate a natural-language question into SQL using the local model.

    Args:
        text: Prompt passed to the text-generation pipeline.
        max_new_tokens: Maximum number of tokens the model may generate
            (defaults to 50, the previous hard-coded value).

    Returns:
        Only the model's newly generated text, with surrounding whitespace
        removed. ``return_full_text=False`` stops the pipeline from echoing the
        prompt. Without it, the user's English question was prepended to the
        "SQL" handed to the database.
    """
    # NOTE(review): sqlcoder-style models are usually prompted with the table
    # schema and an instruction template; the raw question alone may produce
    # weak SQL -- confirm the intended prompt format.
    output = pipe(text, max_new_tokens=max_new_tokens, return_full_text=False)
    return output[0]['generated_text'].strip()
def execute_query(query):
    """Execute one SQL statement against ``DB_CONFIG`` and return all rows.

    Args:
        query: SQL text, executed verbatim (no parameters are bound).

    Returns:
        The rows from ``cursor.fetchall()``, or ``None`` if any
        ``mysql.connector.Error`` is raised while connecting, executing,
        fetching or closing.
    """
    # NOTE(review): in this app `query` is produced by a language model from
    # user input, so it is effectively untrusted SQL. Nothing here restricts
    # the statement type -- connect with a read-only database account.
    try:
        connection = mysql.connector.connect(**DB_CONFIG)
        try:
            cursor = connection.cursor()
            try:
                cursor.execute(query)
                return cursor.fetchall()
            finally:
                # Released even when execute()/fetchall() raises; previously
                # the cursor and connection leaked on every failed query.
                cursor.close()
        finally:
            connection.close()
    except mysql.connector.Error as err:
        print(f"Error: {err}")
        return None
@app.route('/')
def index():
    """Serve the chatbot's landing page (the ``index.html`` template)."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/chatbot', methods=['POST'])
def chatbot():
    """Answer a natural-language question about the database.

    Expects a JSON object body such as ``{"text": "<question>"}``. The question
    is converted to SQL, executed, and the result is summarised in natural
    language by a Groq-hosted chat model.

    Returns:
        ``{"response": ...}`` with status 200 on success;
        ``{"error": ...}`` with status 400 if the body is not a JSON object or
        ``text`` is missing, not a string, or blank;
        ``{"error": ...}`` with status 500 if SQL generation, execution or
        summarisation fails.
    """
    # silent=True yields None for a missing/malformed JSON body instead of
    # aborting with an HTML error page. A non-object body (e.g. a JSON list)
    # has no .get(), so it is rejected the same way.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "No query provided"}), 400
    user_query = data.get('text')
    if not isinstance(user_query, str) or not user_query.strip():
        return jsonify({"error": "No query provided"}), 400

    try:
        # Step 1: Convert natural language to SQL
        sql_query = generate_sql(user_query)
        # Step 2: Execute SQL query
        query_result = execute_query(sql_query)
        if query_result is None:
            return jsonify({"error": "Database query execution failed"}), 500
        # Step 3: Generate natural language response using Groq
        natural_language_response = _summarize_result(user_query, sql_query, query_result)
        return jsonify({"response": natural_language_response})
    except Exception:
        # Top-level request boundary: record the full traceback server-side,
        # but do not echo internal exception text (SQL, hosts, API errors)
        # back to the client.
        app.logger.exception("Chatbot request failed")
        return jsonify({"error": "Internal server error"}), 500


def _summarize_result(user_query, sql_query, query_result):
    """Ask the Groq chat model to summarise a query result in plain language."""
    prompt = (
        f"Original query: {user_query}\n"
        f"SQL query: {sql_query}\n"
        f"Query result: {query_result}\n"
        "Please provide a natural language summary of the query result."
    )
    chat_completion = groq_client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        # NOTE(review): Groq retires older model IDs over time; confirm this
        # one is still served.
        model="llama3-8b-8192",
    )
    return chat_completion.choices[0].message.content
# Run the development server only when executed as a script, not on import.
# (The previous `name == 'main'` referenced an undefined variable.)
if __name__ == '__main__':
    # PORT may be set in the environment; defaults to the original 8000.
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", 8000)))