ankush-003 commited on
Commit
309d442
1 Parent(s): 7944b71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -1,12 +1,17 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
- from transformers import AutoTokenizer
4
- from transformers import TFAutoModelForSequenceClassification
5
- tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
 
 
 
 
 
6
 
7
  def predict(payload, malitious):
8
  inputs = tokenizer(payload, return_tensors="tf")
9
- model = TFAutoModelForSequenceClassification.from_pretrained("models/ankush-003/nosqli_identifier")
10
  logits = model(**inputs).logits
11
  predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
12
  # print(model.config.id2label[predicted_class_id])
 
1
  import gradio as gr
2
  import tensorflow as tf
3
+ # from transformers import AutoTokenizer
4
+ # from transformers import TFAutoModelForSequenceClassification
5
+ # tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
6
+ # Load model directly
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained("ankush-003/nosqli_identifier")
10
+ model = AutoModelForSequenceClassification.from_pretrained("ankush-003/nosqli_identifier")
11
 
12
  def predict(payload, malitious):
13
  inputs = tokenizer(payload, return_tensors="tf")
14
+ # model = TFAutoModelForSequenceClassification.from_pretrained("ankush-003/nosqli_identifier")
15
  logits = model(**inputs).logits
16
  predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
17
  # print(model.config.id2label[predicted_class_id])