from flask import Flask, request, jsonify, render_template
from backend_utils import initialize_all_components, make_predictions
from config import classifier_class_mapping, config
# from flask_cors import CORS, cross_origin
import json


# todo: downgrade version sklearn to 1.0.2

app = Flask(__name__)
# CORS(app)
components = initialize_all_components(config)
# Unpack the eight components in the order initialize_all_components returns them
(
    db_metadata,
    db_constructor,
    model_retrieval,
    model_generative,
    tokenizer_generative,
    model_classifier,
    classifier_head,
    tokenizer_classifier,
) = components[:8]

def call_predict_api(
    input_query, 
    model_retrieval, 
    model_generative,  
    model_classifier, classifier_head,
    tokenizer_generative, tokenizer_classifier,
    db_metadata, db_constructor,
    config
    ):
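    '''
    Thin wrapper around make_predictions: forwards the query, models,
    tokenizers, database handles, and config unchanged and returns the result.
    '''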
    predictions = make_predictions(
        input_query, 
        model_retrieval, 
        model_generative,  
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor,
        config
    )
    return predictions

@app.route("/")
def hello_world():
    return render_template("index.html")

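@app.route('/predict', methods=['POST'])
def predict():
    # silent=True yields None instead of raising on a missing or malformed JSON body
    request_data = request.get_json(silent=True)
    user_query = request_data.get('user_query') if isinstance(request_data, dict) else None

    if user_query is None:
        # Without this branch the view returns None, which Flask rejects with a server error
        return jsonify({'error': "JSON body must contain 'user_query'"}), 400

    predictions = call_predict_api(
        user_query,
        model_retrieval,
        model_generative,
        model_classifier, classifier_head,
        tokenizer_generative, tokenizer_classifier,
        db_metadata, db_constructor,
        config
    )
    # Keep a copy of the most recent predictions on disk
    with open("prediction.txt", 'w') as f:
        json.dump(predictions, f)
    return jsonify({'predictions': predictions})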

if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860)