# Spaces status: Build error
import json
import os

from flask import Flask, request, jsonify, render_template
# from flask_cors import CORS, cross_origin

from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config

# todo: downgrade version sklearn to 1.0.2
app = Flask(__name__)
# CORS(app)

# All models, tokenizers and database artefacts are built once, at import
# time, and shared by every request handled by this process.
components = initialize_all_components(config)

# initialize_all_components hands its results back positionally; positions
# 0-8 are read in this fixed order (any further entries are ignored).
(
    db_metadata,
    db_constructor,
    db_params,
    model_retrieval,
    model_generative,
    tokenizer_generative,
    model_classifier,
    classifier_head,
    tokenizer_classifier,
) = (components[position] for position in range(9))
def call_predict_api(
    input_query,
    model_retrieval,
    model_generative,
    model_classifier, classifier_head,
    tokenizer_generative, tokenizer_classifier,
    db_metadata, db_constructor, db_params,
    config
):
    """Forward a query plus all model/database artefacts to ``make_predictions``.

    Every argument is passed through positionally, in exactly the order it was
    received, and the result of ``make_predictions`` is returned unchanged.
    """
    forwarded_args = (
        input_query,
        model_retrieval,
        model_generative,
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor, db_params,
        config,
    )
    return make_predictions(*forwarded_args)
@app.route("/")
def hello_world():
    """Serve the front-end page by rendering the ``index.html`` template.

    The route registration is required: a plain function that is never
    passed to ``app.route`` is never dispatched to by Flask.
    """
    return render_template("index.html")
# NOTE(review): the "/predict" path is assumed - confirm it matches the URL
# the front-end page requests.
@app.route("/predict")
def predict():
    """Run the prediction pipeline for the ``user_query`` query-string parameter.

    Returns:
        A JSON response ``{"predictions": ...}`` containing whatever
        ``call_predict_api`` produced for the query, or a JSON error body with
        HTTP status 400 when ``user_query`` is missing or blank.
    """
    user_query = request.args.get("user_query")
    print(f"user_query: {user_query}")

    # Without this guard a missing parameter leaves the view with nothing
    # valid to return, which Flask reports as an HTTP 500.
    if user_query is None or not user_query.strip():
        return jsonify({"error": "missing or empty 'user_query' query parameter"}), 400

    print("predicting")
    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor, db_params,
        config
    )
    print("success prediction!")
    return jsonify({
        'predictions': predictions
    })
if __name__ == '__main__':
    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden through the PORT environment variable.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))