Spaces:
Build error
Build error
File size: 2,000 Bytes
c679fb3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from flask import Flask, request, jsonify, render_template
from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config
# from flask_cors import CORS, cross_origin
import json
# TODO: downgrade scikit-learn to version 1.0.2
app = Flask(__name__)
# CORS(app)

# Build every pipeline component once, at import time, so that request
# handlers can reuse them instead of re-initialising per request.
components = initialize_all_components(config)

# Each name is read by its position in the sequence returned by
# initialize_all_components. Indexing (rather than plain tuple unpacking)
# keeps the lookup tolerant of any extra trailing entries.
(
    db_metadata,
    db_constructor,
    model_retrieval,
    model_generative,
    tokenizer_generative,
    model_classifier,
    classifier_head,
    tokenizer_classifier,
) = (components[position] for position in range(8))
def call_predict_api(
    input_query,
    model_retrieval,
    model_generative,
    model_classifier, classifier_head,
    tokenizer_generative, tokenizer_classifier,
    db_metadata, db_constructor,
    config
):
    '''
    Thin wrapper around make_predictions.

    Forwards every argument, unchanged and in the same order, to
    make_predictions and hands back whatever that call returns.
    '''
    return make_predictions(
        input_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        config,
    )
@app.route("/")
def hello_world():
    '''
    Serve the root URL by rendering the "index.html" template.
    '''
    landing_page = render_template("index.html")
    return landing_page
@app.route('/predict', methods=['POST'])
def predict():
    '''
    Handle POST /predict: run the prediction pipeline on the posted query.

    The request body must be a JSON object containing a 'user_query' key.
    On success the predictions are written to "prediction.txt" and returned
    as {'predictions': <predictions>}.

    A body that is missing, is not valid JSON, is not a JSON object, or has
    no (or a null) 'user_query' is answered with a 400 status and a JSON
    {'error': <message>} payload instead of an unhandled server error.
    '''
    # silent=True makes get_json return None rather than raise on a missing
    # or malformed JSON body, so all bad-input cases share the 400 path below.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    user_query = request_data.get('user_query')
    if user_query is None:
        return jsonify({'error': "missing required field 'user_query'"}), 400

    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor,
        config
    )

    # NOTE(review): this file is overwritten on every request, so concurrent
    # requests race on it -- confirm whether anything still reads it.
    with open("prediction.txt", 'w') as f:
        json.dump(predictions, f)

    return {
        'predictions': predictions
    }
if __name__ == '__main__':
    # Only start the built-in server when this file is executed directly;
    # host 0.0.0.0 makes it listen on all network interfaces.
    app.run(port=7860, host="0.0.0.0")
|