imamnurby commited on
Commit
c679fb3
1 Parent(s): e34bb70

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, render_template
2
+ from backend_utils import initialize_all_components, make_predictions
3
+ from config import classifier_class_mapping, config
4
+ # from flask_cors import CORS, cross_origin
5
+ import json
6
+
7
+
8
+ # todo: downgrade version sklearn to 1.0.2
9
+
10
+ app = Flask(__name__)
11
+ # CORS(app)
12
+ components = initialize_all_components(config)
13
+ db_metadata = components[0]
14
+ db_constructor = components[1]
15
+ model_retrieval = components[2]
16
+ model_generative = components[3]
17
+ tokenizer_generative = components[4]
18
+ model_classifier = components[5]
19
+ classifier_head = components[6]
20
+ tokenizer_classifier = components[7]
21
+
22
+ def call_predict_api(
23
+ input_query,
24
+ model_retrieval,
25
+ model_generative,
26
+ model_classifier, classifier_head,
27
+ tokenizer_generative, tokenizer_classifier,
28
+ db_metadata, db_constructor,
29
+ config
30
+ ):
31
+ '''
32
+ wrapper to the make prediction function
33
+ '''
34
+ predictions = make_predictions(
35
+ input_query,
36
+ model_retrieval,
37
+ model_generative,
38
+ model_classifier, classifier_head,
39
+ tokenizer_generative, tokenizer_classifier,
40
+ db_metadata, db_constructor,
41
+ config
42
+ )
43
+ return predictions
44
+
45
+ @app.route("/")
46
+ def hello_world():
47
+ return render_template("index.html")
48
+
49
+ @app.route('/predict', methods=['POST'])
50
+ def predict():
51
+ request_data = request.get_json()
52
+ user_query = request_data.get('user_query', None)
53
+
54
+ if user_query != None:
55
+ predictions = call_predict_api(
56
+ user_query,
57
+ model_retrieval,
58
+ model_generative,
59
+ model_classifier, classifier_head,
60
+ tokenizer_generative, tokenizer_classifier,
61
+ db_metadata, db_constructor,
62
+ config
63
+ )
64
+ with open("prediction.txt", 'w') as f:
65
+ json.dump(predictions, f)
66
+ return {
67
+ 'predictions': predictions
68
+ }
69
+
70
+ if __name__ == '__main__':
71
+ app.run(host="0.0.0.0", port=7860)