ArduinoProg / app.py
imamnurby's picture
Update app.py
97dfffd
raw
history blame
2.35 kB
from flask import Flask, request, jsonify, render_template
from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config
import subprocess
# from flask_cors import CORS, cross_origin
import json
# TODO: downgrade scikit-learn to version 1.0.2
# Flask application object; the route decorators further down register on it.
app = Flask(__name__)
# CORS(app)

# Build all backend components once, at import time, then bind each one to a
# module-level name according to its position (0..9) in the returned sequence.
components = initialize_all_components(config)
(
    db_metadata,
    db_constructor,
    db_params,
    ex_list,
    model_retrieval,
    model_generative,
    tokenizer_generative,
    model_classifier,
    classifier_head,
    tokenizer_classifier,
) = (components[slot] for slot in range(10))
def call_predict_api(
    input_query,
    model_retrieval,
    model_generative,
    model_classifier, classifier_head,
    tokenizer_generative, tokenizer_classifier,
    db_metadata, db_constructor, db_params, ex_list,
    config
):
    '''
    Thin wrapper around ``make_predictions``.

    Every argument is forwarded positionally, in the order received, and the
    value returned by ``make_predictions`` is handed back unchanged.
    '''
    forwarded_args = (
        input_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        db_params,
        ex_list,
        config,
    )
    return make_predictions(*forwarded_args)
@app.route("/")
def hello_world():
    '''
    Serve the root URL by rendering the ``index.html`` template.
    '''
    landing_page = render_template("index.html")
    return landing_page
@app.route('/predict', methods=['POST', 'GET'])
def predict():
    '''
    Run the prediction pipeline for the ``user_query`` URL parameter.

    The query is read from the URL query string (``request.args``) for both
    GET and POST requests; the request body is not inspected.

    Returns:
        A JSON response ``{"predictions": <result>}`` where ``<result>`` is
        whatever ``call_predict_api`` returned (including the string
        ``'null'``, which is passed through as-is).
        If ``user_query`` is absent, a 400 response carrying an ``error``
        message and ``"predictions": "null"`` is returned instead.
    '''
    user_query = request.args.get("user_query")
    print(f"user_query: {user_query}")

    # Guard clause: without a query there is nothing to predict. Answer with
    # an explicit client error instead of falling through without a valid
    # response (which surfaced as a 500 from Flask).
    if user_query is None:
        return jsonify({
            'error': "missing required query parameter 'user_query'",
            'predictions': 'null',
        }), 400

    print("predicting")
    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor, db_params, ex_list,
        config
    )

    # A 'null' string result needs no special branch: it serialises to the
    # same payload as any other prediction value.
    return jsonify({
        'predictions': predictions
    })
if __name__ == '__main__':
    # Only start the built-in Flask server when this file is executed as a
    # script; "0.0.0.0" makes it listen on all network interfaces.
    app.run(
        host="0.0.0.0",
        port=7860,
    )