File size: 1,196 Bytes
b54e71a
7944b71
309d442
 
e385089
309d442
0749cb7
309d442
e385089
 
99c7040
bb27896
 
 
309d442
6ef3ddb
 
 
bb27896
c5cc59e
bb27896
b54e71a
6ef3ddb
 
 
c5cc59e
6ef3ddb
93b0079
6ef3ddb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import gradio as gr
import tensorflow as tf
# from transformers import AutoTokenizer
# from transformers import TFAutoModelForSequenceClassification

# Load model directly
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

# The classifier weights come from the fine-tuned checkpoint, but the tokenizer
# is loaded from the base "distilbert-base-uncased" checkpoint instead.
# NOTE(review): presumably the fine-tuned repo does not ship tokenizer files
# and was trained on the base DistilBERT vocabulary — confirm the two match.
_CLASSIFIER_CHECKPOINT = "ankush-003/nosqli_identifier"
_TOKENIZER_CHECKPOINT = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
model = TFAutoModelForSequenceClassification.from_pretrained(_CLASSIFIER_CHECKPOINT)

def predict(payload, malitious):
    """Classify a payload with the model and echo the label the user expects.

    Args:
        payload: Text of the payload to classify (from the Gradio text box).
        malitious: Checkbox value; truthy when the user marks the payload as
            malicious.

    Returns:
        A ``(prediction, expected)`` tuple of strings: the label looked up in
        ``model.config.id2label`` for the highest-scoring class, and
        ``"Malitious"`` or ``"Benign"`` depending on *malitious*.
    """
    # truncation=True clips the token sequence to the tokenizer's maximum
    # length, so an over-long payload no longer exceeds the model's maximum
    # sequence length and fails at inference time.
    inputs = tokenizer(payload, truncation=True, return_tensors="tf")
    logits = model(**inputs).logits
    # One payload per call, so take the argmax of the first (only) row.
    predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
    # NOTE(review): "Malitious" is misspelled but kept unchanged in case it is
    # meant to match the model's id2label strings — confirm and fix both.
    expected = "Malitious" if malitious else "Benign"
    return model.config.id2label[predicted_class_id], expected

# Build the UI: a text payload plus a "malicious?" checkbox in, and two text
# boxes out (the model's label next to the label the user expected).
prediction_box = gr.Textbox(label="Model Prediction")
expected_box = gr.Textbox(label="Expected")

demo = gr.Interface(
    fn=predict,
    inputs=["text", "checkbox"],
    outputs=[prediction_box, expected_box],
)
demo.launch(debug=True)
# gr.Interface.load("models/ankush-003/nosqli_identifier").launch()