Erfan11 commited on
Commit
5cf5335
1 Parent(s): 4db40a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -30
app.py CHANGED
@@ -1,36 +1,24 @@
1
  import os
2
- from transformers import TFBertForSequenceClassification, BertTokenizerFast
 
 
3
 
4
- def load_model(model_name):
5
- try:
6
- # Load TensorFlow model from Hugging Face
7
- model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'))
8
- except OSError:
9
- # Fallback to PyTorch model if TensorFlow fails
10
- model = TFBertForSequenceClassification.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'), from_pt=True)
11
- return model
12
 
13
- def load_tokenizer(model_name):
14
- tokenizer = BertTokenizerFast.from_pretrained(model_name, use_auth_token=os.getenv('API_KEY'))
15
- return tokenizer
16
 
17
- def predict(text, model, tokenizer):
18
- inputs = tokenizer(text, return_tensors="tf")
19
- outputs = model(**inputs)
20
- return outputs
21
 
22
- def main():
23
- model_name = os.getenv('MODEL_PATH')
24
- if model_name is None:
25
- raise ValueError("MODEL_PATH environment variable not set or is None")
 
 
 
26
 
27
- model = load_model(model_name)
28
- tokenizer = load_tokenizer(model_name)
29
-
30
- # Example prediction
31
- text = "Sample input text"
32
- result = predict(text, model, tokenizer)
33
- print(result)
34
-
35
- if __name__ == "__main__":
36
- main()
 
1
  import os
2
+ from flask import Flask, request, jsonify
3
+ from dotenv import load_dotenv
4
+ import tensorflow as tf
5
 
6
+ load_dotenv()
7
+ api_key = os.getenv('HF_API_KEY')
8
+ model_path = os.getenv('MODEL_PATH')
 
 
 
 
 
9
 
10
+ app = Flask(__name__)
 
 
11
 
12
+ def load_model():
13
+ return tf.keras.models.load_model(model_path)
 
 
14
 
15
+ @app.route('/predict', methods=['POST'])
16
+ def predict():
17
+ data = request.get_json()
18
+ text = data['text']
19
+ model = load_model()
20
+ prediction = model.predict([text])
21
+ return jsonify(prediction.tolist())
22
 
23
+ if __name__ == '__main__':
24
+ app.run()