"""Flask front end: serves the UI page and exposes a JSON prediction endpoint."""
import json

from flask import Flask, request, jsonify, render_template

from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config

# from flask_cors import CORS, cross_origin

# TODO: downgrade scikit-learn to version 1.0.2.

app = Flask(__name__)
# CORS(app)

# Build every model / tokenizer / database component once at import time so
# all requests share them. The positional order below mirrors the sequence
# returned by initialize_all_components (see backend_utils).
components = initialize_all_components(config)
db_metadata = components[0]
db_constructor = components[1]
model_retrieval = components[2]
model_generative = components[3]
tokenizer_generative = components[4]
model_classifier = components[5]
classifier_head = components[6]
tokenizer_classifier = components[7]


def call_predict_api(
    input_query,
    model_retrieval,
    model_generative,
    model_classifier,
    classifier_head,
    tokenizer_generative,
    tokenizer_classifier,
    db_metadata,
    db_constructor,
    config,
):
    """Forward a query and all pipeline components to ``make_predictions``.

    This is a thin wrapper: every argument is passed through unchanged, in
    the same order, and whatever ``make_predictions`` returns is returned.
    """
    predictions = make_predictions(
        input_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        config,
    )
    return predictions


@app.route("/")
def hello_world():
    """Render the ``index.html`` template as the landing page."""
    return render_template("index.html")


@app.route('/predict', methods=['POST'])
def predict():
    """Run the prediction pipeline on the posted query.

    Expects a JSON object body containing a ``user_query`` field. On success
    the predictions are also written (overwriting) to ``prediction.txt`` and
    returned as ``{"predictions": ...}``.

    Returns:
        The predictions payload, or a JSON error with HTTP 400 when the body
        is not a JSON object or ``user_query`` is missing/null.
    """
    # silent=True yields None instead of raising/aborting on a missing or
    # malformed JSON body, so we can answer with a clean 400 ourselves.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    user_query = request_data.get('user_query', None)
    if user_query is None:
        # Previously this fell through and raised UnboundLocalError (HTTP 500).
        return jsonify({'error': "Missing required field 'user_query'."}), 400

    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier,
        classifier_head,
        tokenizer_generative,
        tokenizer_classifier,
        db_metadata,
        db_constructor,
        config,
    )

    # Persist the latest result; json.dump requires predictions to be
    # JSON-serialisable. NOTE(review): every request overwrites the same
    # file, so concurrent requests may race on it.
    with open("prediction.txt", 'w', encoding="utf-8") as f:
        json.dump(predictions, f)

    return {
        'predictions': predictions
    }


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860)