ArduinoProg / app.py
imamnurby's picture
Update app.py
bd1a9b4
raw
history blame
2.05 kB
from flask import Flask, request, jsonify, render_template
from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config
# from flask_cors import CORS, cross_origin
import json
# todo: downgrade version sklearn to 1.0.2
app = Flask(__name__)
# CORS(app)
components = initialize_all_components(config)
db_metadata = components[0]
db_constructor = components[1]
model_retrieval = components[2]
model_generative = components[3]
tokenizer_generative = components[4]
model_classifier = components[5]
classifier_head = components[6]
tokenizer_classifier = components[7]
def call_predict_api(
input_query,
model_retrieval,
model_generative,
model_classifier, classifier_head,
tokenizer_generative, tokenizer_classifier,
db_metadata, db_constructor,
config
):
'''
wrapper to the make prediction function
'''
predictions = make_predictions(
input_query,
model_retrieval,
model_generative,
model_classifier, classifier_head,
tokenizer_generative, tokenizer_classifier,
db_metadata, db_constructor,
config
)
return predictions
@app.route("/")
def hello_world():
return render_template("index.html")
@app.route('/predict', methods=['POST', 'GET'])
def predict():
#request_data = request.get_json()
#user_query = request_data.get('user_query', None)
user_query = request.args.get("user_query")
print(f"user_query: {user_query}")
if user_query != None:
predictions = call_predict_api(
user_query,
model_retrieval,
model_generative,
model_classifier, classifier_head,
tokenizer_generative, tokenizer_classifier,
db_metadata, db_constructor,
config
)
print("success prediction!")
return {
'predictions': predictions
}
if __name__ == '__main__':
app.run(host="0.0.0.0", port=7860)