serdarakyol commited on
Commit
f265d17
1 Parent(s): d0cc15f

added tensorflow usage

Browse files
Files changed (1) hide show
  1. README.md +24 -4
README.md CHANGED
@@ -9,7 +9,7 @@ The dataset downloaded from interpress. This dataset is real world data. Actuall
9
  ## Model
10
  Model accuracy on train data and validation data is %97.
11
 
12
- ## Usage
13
  ```sh
14
  pip install transformers or pip install transformers==4.3.3
15
  ```
@@ -17,7 +17,6 @@ pip install transformers or pip install transformers==4.3.3
17
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
18
 
19
  tokenizer = AutoTokenizer.from_pretrained("serdarakyol/interpress-turkish-news-classification")
20
-
21
  model = AutoModelForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
22
  ```
23
 
@@ -44,7 +43,7 @@ def prediction(news):
44
  return_attention_mask=True,
45
  padding='max_length',
46
  truncation=True,
47
- return_tensors='pt') # for tf tensors, switch pt to tf
48
 
49
  inputs = indices["input_ids"].clone().detach().to(device)
50
  masks = indices["attention_mask"].clone().detach().to(device)
@@ -78,7 +77,28 @@ pred = prediction(news)
78
  print(labels[pred])
79
  # > World
80
  ```
81
- Thanks to @yavuzkomecoglu for contributes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  If you have any question, please, don't hesitate to contact with me
84
  [![linkedin](https://img.shields.io/badge/LinkedIn-0077B5?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/serdarakyol55/)
 
9
  ## Model
10
  Model accuracy on train data and validation data is %97.
11
 
12
+ ## Usage for Torch
13
  ```sh
14
  pip install transformers or pip install transformers==4.3.3
15
  ```
 
17
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
18
 
19
  tokenizer = AutoTokenizer.from_pretrained("serdarakyol/interpress-turkish-news-classification")
 
20
  model = AutoModelForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
21
  ```
22
 
 
43
  return_attention_mask=True,
44
  padding='max_length',
45
  truncation=True,
46
+ return_tensors='pt')
47
 
48
  inputs = indices["input_ids"].clone().detach().to(device)
49
  masks = indices["attention_mask"].clone().detach().to(device)
 
77
  print(labels[pred])
78
  # > World
79
  ```
80
+ ## Usage for Tensorflow
81
+ ```sh
82
+ pip install transformers or pip install transformers==4.3.3
83
+
84
+ import tensorflow as tf
85
+ from transformers import BertTokenizer, TFBertForSequenceClassification
86
+ import numpy as np
87
+
88
+ tokenizer = BertTokenizer.from_pretrained('serdarakyol/interpress-turkish-news-classification')
89
+ model = TFBertForSequenceClassification.from_pretrained("serdarakyol/interpress-turkish-news-classification")
90
+
91
+ inputs = tokenizer(news, return_tensors="tf")
92
+ inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
93
+
94
+ outputs = model(inputs)
95
+ loss = outputs.loss
96
+ logits = outputs.logits
97
+ pred = np.argmax(logits,axis=1)[0]
98
+ labels[pred]
99
+ # > World
100
+ ```
101
+ Thanks to [@yavuzkomecoglu](https://huggingface.co/yavuzkomecoglu) for contributes
102
 
103
  If you have any question, please, don't hesitate to contact with me
104
  [![linkedin](https://img.shields.io/badge/LinkedIn-0077B5?style=for-the-badge&logo=linkedin&logoColor=white)](https://www.linkedin.com/in/serdarakyol55/)