serdarakyol
commited on
Commit
•
f265d17
1
Parent(s):
d0cc15f
added tensorflow usage
Browse files
README.md
CHANGED
@@ -9,7 +9,7 @@ The dataset downloaded from interpress. This dataset is real world data. Actuall
|
|
9 |
## Model
|
10 |
Model accuracy on train data and validation data is %97.
|
11 |
|
12 |
-
## Usage
|
13 |
```sh
|
14 |
pip install transformers or pip install transformers==4.3.3
|
15 |
```
|
@@ -17,7 +17,6 @@ pip install transformers or pip install transformers==4.3.3
|
|
17 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
18 |
|
19 |
tokenizer = AutoTokenizer.from_pretrained("serdarakyol/interpress-turkish-news-classification")
|
20 |
-
|
21 |
model = AutoModelForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
|
22 |
```
|
23 |
|
@@ -44,7 +43,7 @@ def prediction(news):
|
|
44 |
return_attention_mask=True,
|
45 |
padding='max_length',
|
46 |
truncation=True,
|
47 |
-
return_tensors='pt')
|
48 |
|
49 |
inputs = indices["input_ids"].clone().detach().to(device)
|
50 |
masks = indices["attention_mask"].clone().detach().to(device)
|
@@ -78,7 +77,28 @@ pred = prediction(news)
|
|
78 |
print(labels[pred])
|
79 |
# > World
|
80 |
```
|
81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
|
83 |
If you have any question, please, don't hesitate to contact with me
|
84 |
[![linkedin](https://img.shields.io/badge/LinkedIn-0077B5?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/serdarakyol55/)
|
|
|
9 |
## Model
|
10 |
Model accuracy on train data and validation data is %97.
|
11 |
|
12 |
+
## Usage for Torch
|
13 |
```sh
|
14 |
pip install transformers or pip install transformers==4.3.3
|
15 |
```
|
|
|
17 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
18 |
|
19 |
tokenizer = AutoTokenizer.from_pretrained("serdarakyol/interpress-turkish-news-classification")
|
|
|
20 |
model = AutoModelForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
|
21 |
```
|
22 |
|
|
|
43 |
return_attention_mask=True,
|
44 |
padding='max_length',
|
45 |
truncation=True,
|
46 |
+
return_tensors='pt')
|
47 |
|
48 |
inputs = indices["input_ids"].clone().detach().to(device)
|
49 |
masks = indices["attention_mask"].clone().detach().to(device)
|
|
|
77 |
print(labels[pred])
|
78 |
# > World
|
79 |
```
|
80 |
+
## Usage for Tensorflow
|
81 |
+
```sh
|
82 |
+
pip install transformers or pip install transformers==4.3.3
|
83 |
+
|
84 |
+
import tensorflow as tf
|
85 |
+
from transformers import BertTokenizer, TFBertForSequenceClassification
|
86 |
+
import numpy as np
|
87 |
+
|
88 |
+
tokenizer = BertTokenizer.from_pretrained('serdarakyol/interpress-turkish-news-classification')
|
89 |
+
model = TFBertForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
|
90 |
+
|
91 |
+
inputs = tokenizer(news, return_tensors="tf")
|
92 |
+
inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
|
93 |
+
|
94 |
+
outputs = model(inputs)
|
95 |
+
loss = outputs.loss
|
96 |
+
logits = outputs.logits
|
97 |
+
pred = np.argmax(logits,axis=1)[0]
|
98 |
+
labels[pred]
|
99 |
+
# > World
|
100 |
+
```
|
101 |
+
Thanks to [@yavuzkomecoglu](https://huggingface.co/yavuzkomecoglu) for contributes
|
102 |
|
103 |
If you have any question, please, don't hesitate to contact with me
|
104 |
[![linkedin](https://img.shields.io/badge/LinkedIn-0077B5?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/serdarakyol55/)
|