duoquote committed on
Commit 0255e9b
Parent: 1238daf

Add address extraction functionality using BERT model

.gitignore ADDED
@@ -0,0 +1,2 @@
+ venv
+ condaenv
README.md CHANGED
@@ -1,3 +1,98 @@
  ---
  license: mit
+ base_model: dbmdz/bert-base-turkish-cased
+ tags:
+ - ner
+ - token-classification
+ - pytorch
+ - turkish
+ - tr
+ - dbmdz
+ - bert
+ - bert-base-cased
+ - bert-base-turkish-cased
+ widget:
+ - text: "Bağlarbaşı Mahallesi, Zübeyde Hanım Caddesi No: 10 / 3 34710 Üsküdar/İstanbul"
  ---
+
+ # address-extraction
+
+ ![Next Geography](https://nextgeography.com/wp-content/uploads/2022/02/next-geo-logo-1.png)
+
+ This is a simple library for extracting address components from Turkish text. train.py contains the training code and is included for reference only, not to be run: the model was trained on our own dataset of addresses, which is not included in this repo. predict.py is a small script that runs the model on a single address.
+
+ The model is a fine-tuned version of [dbmdz/bert-base-turkish-cased](https://huggingface.co/dbmdz/bert-base-turkish-cased) from [Hugging Face](https://huggingface.co/).
+
+ ## Example Results
+
+ ```
+ (g:\projects\address-extraction\venv) G:\projects\address-extraction>python predict.py
+ Osmangazi Mahallesi, Hoca Ahmet Yesevi Cd. No:34, 16050 Osmangazi/Bursa
+ Osmangazi Mahalle 98.65%
+ Hoca Ahmet Yesevi Cadde 97.63%
+ 34 Bina Numarası 98.92%
+ 16050 Posta Kodu 97.83%
+ Osmangazi İlçe 98.97%
+ Bursa İl 99.21%
+ Average Score: 0.9902257982053255
+ Labels Found: 6
+ ----------------------------------------------------------------------
+ Karşıyaka Mahallesi, Mavişehir Caddesi No: 91, Daire 4, 35540 Karşıyaka/İzmir
+ Karşıyaka Mahalle 99.11%
+ Mavişehir Cadde 97.16%
+ 91 Bina Numarası 98.73%
+ 4 Kat 29.06%
+ 35540 Posta Kodu 98.65%
+ Karşıyaka İlçe 99.17%
+ İzmir İl 99.16%
+ Average Score: 0.9237866433043229
+ Labels Found: 7
+ ----------------------------------------------------------------------
+ Selçuklu Mahallesi, Atatürk Bulvarı No: 55, 42050 Selçuklu/Konya
+ Selçuklu Mahalle 98.67%
+ Atatürk Cadde 57.06%
+ 55 Bina Numarası 98.94%
+ 42050 Posta Kodu 98.81%
+ Selçuklu İlçe 99.06%
+ Konya İl 99.22%
+ Average Score: 0.9659512996673584
+ Labels Found: 6
+ ----------------------------------------------------------------------
+ Alsancak Mahallesi, 1475. Sk. No:3, 35220 Konak/İzmir
+ Alsancak Mahalle 99.38%
+ 1475 Sokak 96.04%
+ 3 Bina Numarası 98.06%
+ 35220 Posta Kodu 98.75%
+ Konak İlçe 99.23%
+ İzmir İl 99.16%
+ Average Score: 0.9909308176291617
+ Labels Found: 6
+ ----------------------------------------------------------------------
+ Kocatepe Mahallesi, Yaşam Caddesi 3. Sokak No:4, 06420 Bayrampaşa/İstanbul
+ Kocatepe Mahalle 99.46%
+ Yaşam Cadde 94.07%
+ 3 Sokak 84.07%
+ 4 Bina Numarası 98.42%
+ 06420 Posta Kodu 98.54%
+ Bayrampaşa İlçe 98.97%
+ İstanbul İl 98.98%
+ Average Score: 0.9832726591511777
+ Labels Found: 7
+ ----------------------------------------------------------------------
+ ```
+
+ ## Installation & Usage
+
+ The environment.yml file describes the conda environment used to run the model. The environment is configured for CUDA-enabled GPUs, but it should also work without a GPU. To run the model, use the following commands:
+
+ ```bash
+ conda env create -f environment.yml -p ./condaenv
+ conda activate ./condaenv
+
+ python predict.py
+ ```
+
+
+ ## License
+
+ This project is licensed under the terms of the MIT license.
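The command-line flow in predict.py can also be reproduced in a few lines of Python. The sketch below is an assumption-laden condensation of what predict.py does: it assumes the working directory is the repo root, so that ./model and labels.json resolve to the files added in this commit, and it only prints per-token predictions (predict.py additionally merges B-/I- tokens into full spans).

```python
import orjson
import torch
from transformers import AutoTokenizer, pipeline

# The tokenizer comes from the base model; the fine-tuned weights live in ./model.
tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-cased")

# labels.json maps the generic LABEL_<id> names in model/config.json to the Turkish span types.
with open("labels.json", "rb") as f:
    id_to_label = {int(k): v for k, v in orjson.loads(f.read()).items()}

ner = pipeline(
    "ner",
    model="./model",
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

text = "Osmangazi Mahallesi, Hoca Ahmet Yesevi Cd. No:34, 16050 Osmangazi/Bursa"
for token in ner(text):
    # token["entity"] is e.g. "LABEL_7"; strip the prefix to recover the label id.
    label = id_to_label[int(token["entity"].split("_")[1])]
    print(text[token["start"]:token["end"]], label, f"{token['score']:.2%}")
```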
environment.yml ADDED
@@ -0,0 +1,148 @@
+ name: address-extraction
+ channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+ dependencies:
+ - blas=2.120
+ - blas-devel=3.9.0
+ - brotli-python=1.1.0
+ - bzip2=1.0.8
+ - ca-certificates=2023.11.17
+ - certifi=2023.11.17
+ - charset-normalizer=3.3.2
+ - cuda-cccl=12.3.101
+ - cuda-cudart=12.1.105
+ - cuda-cudart-dev=12.1.105
+ - cuda-cupti=12.1.105
+ - cuda-libraries=12.1.0
+ - cuda-libraries-dev=12.1.0
+ - cuda-nvrtc=12.1.105
+ - cuda-nvrtc-dev=12.1.105
+ - cuda-nvtx=12.1.105
+ - cuda-opencl=12.3.101
+ - cuda-opencl-dev=12.3.101
+ - cuda-profiler-api=12.3.101
+ - cuda-runtime=12.1.0
+ - filelock=3.13.1
+ - freetype=2.12.1
+ - idna=3.6
+ - intel-openmp=2023.2.0
+ - jinja2=3.1.3
+ - lcms2=2.16
+ - lerc=4.0.0
+ - libblas=3.9.0
+ - libcblas=3.9.0
+ - libcublas=12.1.0.26
+ - libcublas-dev=12.1.0.26
+ - libcufft=11.0.2.4
+ - libcufft-dev=11.0.2.4
+ - libcurand=10.3.4.107
+ - libcurand-dev=10.3.4.107
+ - libcusolver=11.4.4.55
+ - libcusolver-dev=11.4.4.55
+ - libcusparse=12.0.2.55
+ - libcusparse-dev=12.0.2.55
+ - libdeflate=1.19
+ - libffi=3.4.2
+ - libhwloc=2.9.3
+ - libiconv=1.17
+ - libjpeg-turbo=3.0.0
+ - liblapack=3.9.0
+ - liblapacke=3.9.0
+ - libnpp=12.0.2.50
+ - libnpp-dev=12.0.2.50
+ - libnvjitlink=12.1.105
+ - libnvjitlink-dev=12.1.105
+ - libnvjpeg=12.1.1.14
+ - libnvjpeg-dev=12.1.1.14
+ - libpng=1.6.39
+ - libsqlite=3.44.2
+ - libtiff=4.6.0
+ - libuv=1.44.2
+ - libwebp-base=1.3.2
+ - libxcb=1.15
+ - libxml2=2.12.4
+ - libzlib=1.2.13
+ - m2w64-gcc-libgfortran=5.3.0
+ - m2w64-gcc-libs=5.3.0
+ - m2w64-gcc-libs-core=5.3.0
+ - m2w64-gmp=6.1.0
+ - m2w64-libwinpthread-git=5.0.0.4634.697f757
+ - markupsafe=2.1.4
+ - mkl=2023.2.0
+ - mkl-devel=2023.2.0
+ - mkl-include=2023.2.0
+ - mpmath=1.3.0
+ - msys2-conda-epoch=20160418
+ - networkx=3.2.1
+ - numpy=1.26.3
+ - openjpeg=2.5.0
+ - openssl=3.2.0
+ - pillow=10.2.0
+ - pip=23.3.2
+ - pthread-stubs=0.4
+ - pthreads-win32=2.9.1
+ - pysocks=1.7.1
+ - python=3.10.13
+ - python_abi=3.10
+ - pytorch=2.1.2
+ - pytorch-cuda=12.1
+ - pytorch-mutex=1.0
+ - pyyaml=6.0.1
+ - requests=2.31.0
+ - setuptools=69.0.3
+ - sympy=1.12
+ - tbb=2021.11.0
+ - tk=8.6.13
+ - typing_extensions=4.9.0
+ - ucrt=10.0.22621.0
+ - urllib3=2.1.0
+ - vc=14.3
+ - vc14_runtime=14.38.33130
+ - vs2015_runtime=14.38.33130
+ - wheel=0.42.0
+ - win_inet_pton=1.1.0
+ - xorg-libxau=1.0.11
+ - xorg-libxdmcp=1.1.3
+ - xz=5.2.6
+ - yaml=0.2.5
+ - zstd=1.5.5
+ - pip:
+   - accelerate==0.26.1
+   - aiohttp==3.9.1
+   - aiosignal==1.3.1
+   - async-timeout==4.0.3
+   - attrs==23.2.0
+   - colorama==0.4.6
+   - datasets==2.16.1
+   - dill==0.3.7
+   - frozenlist==1.4.1
+   - fsspec==2023.10.0
+   - huggingface-hub==0.20.3
+   - joblib==1.3.2
+   - multidict==6.0.4
+   - multiprocess==0.70.15
+   - orjson==3.9.12
+   - packaging==23.2
+   - pandas==2.2.0
+   - psutil==5.9.8
+   - pyarrow==15.0.0
+   - pyarrow-hotfix==0.6
+   - python-dateutil==2.8.2
+   - pytz==2023.3.post1
+   - regex==2023.12.25
+   - safetensors==0.4.1
+   - scikit-learn==1.4.0
+   - scipy==1.12.0
+   - six==1.16.0
+   - threadpoolctl==3.2.0
+   - tokenizers==0.15.1
+   - torchaudio==2.1.2
+   - torchvision==0.16.2
+   - tqdm==4.66.1
+   - transformers==4.37.0
+   - tzdata==2023.4
+   - xxhash==3.4.1
+   - yarl==1.9.4
labels.json ADDED
@@ -0,0 +1 @@
+ {"1": "B-\u00dclke", "2": "I-\u00dclke", "3": "B-\u0130l", "4": "I-\u0130l", "5": "B-\u0130l\u00e7e", "6": "I-\u0130l\u00e7e", "7": "B-Mahalle", "8": "I-Mahalle", "9": "B-Cadde", "10": "I-Cadde", "11": "B-Sokak", "12": "I-Sokak", "13": "B-Bina Ad\u0131", "14": "I-Bina Ad\u0131", "15": "B-Bina Numaras\u0131", "16": "I-Bina Numaras\u0131", "17": "B-Yer Ad\u0131", "18": "I-Yer Ad\u0131", "19": "B-Site", "20": "I-Site", "21": "B-Adres Detay", "22": "I-Adres Detay", "23": "B-Blok No", "24": "I-Blok No", "25": "B-Bulvar", "26": "I-Bulvar", "27": "B-Daire No", "28": "I-Daire No", "29": "B-Posta Kodu", "30": "I-Posta Kodu", "31": "B-Kat", "32": "I-Kat", "0": "O"}
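The numbering in labels.json mirrors the scheme train.py uses when it builds label_to_id: id 0 is the outside tag "O", and the i-th span type from the annotation project gets its B- tag at 2i+1 and its I- tag at 2i+2. A small sanity-check sketch, assuming labels.json is in the working directory:

```python
import json

with open("labels.json") as f:
    id_to_label = {int(k): v for k, v in json.load(f).items()}

# Id 0 is "O"; odd ids are B- (beginning) tags, even non-zero ids are I- (inside) tags.
assert id_to_label[0] == "O"
for i, name in sorted(id_to_label.items()):
    if i == 0:
        continue
    expected_prefix = "B-" if i % 2 == 1 else "I-"
    assert name.startswith(expected_prefix), (i, name)

print(f"{len(id_to_label)} ids covering {(len(id_to_label) - 1) // 2} span types")
```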
model/config.json ADDED
@@ -0,0 +1,95 @@
+ {
+ "_name_or_path": "dbmdz/bert-base-turkish-cased",
+ "architectures": [
+ "BertForTokenClassification"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "classifier_dropout": null,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "id2label": {
+ "0": "LABEL_0",
+ "1": "LABEL_1",
+ "2": "LABEL_2",
+ "3": "LABEL_3",
+ "4": "LABEL_4",
+ "5": "LABEL_5",
+ "6": "LABEL_6",
+ "7": "LABEL_7",
+ "8": "LABEL_8",
+ "9": "LABEL_9",
+ "10": "LABEL_10",
+ "11": "LABEL_11",
+ "12": "LABEL_12",
+ "13": "LABEL_13",
+ "14": "LABEL_14",
+ "15": "LABEL_15",
+ "16": "LABEL_16",
+ "17": "LABEL_17",
+ "18": "LABEL_18",
+ "19": "LABEL_19",
+ "20": "LABEL_20",
+ "21": "LABEL_21",
+ "22": "LABEL_22",
+ "23": "LABEL_23",
+ "24": "LABEL_24",
+ "25": "LABEL_25",
+ "26": "LABEL_26",
+ "27": "LABEL_27",
+ "28": "LABEL_28",
+ "29": "LABEL_29",
+ "30": "LABEL_30",
+ "31": "LABEL_31",
+ "32": "LABEL_32"
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "label2id": {
+ "LABEL_0": 0,
+ "LABEL_1": 1,
+ "LABEL_10": 10,
+ "LABEL_11": 11,
+ "LABEL_12": 12,
+ "LABEL_13": 13,
+ "LABEL_14": 14,
+ "LABEL_15": 15,
+ "LABEL_16": 16,
+ "LABEL_17": 17,
+ "LABEL_18": 18,
+ "LABEL_19": 19,
+ "LABEL_2": 2,
+ "LABEL_20": 20,
+ "LABEL_21": 21,
+ "LABEL_22": 22,
+ "LABEL_23": 23,
+ "LABEL_24": 24,
+ "LABEL_25": 25,
+ "LABEL_26": 26,
+ "LABEL_27": 27,
+ "LABEL_28": 28,
+ "LABEL_29": 29,
+ "LABEL_3": 3,
+ "LABEL_30": 30,
+ "LABEL_31": 31,
+ "LABEL_32": 32,
+ "LABEL_4": 4,
+ "LABEL_5": 5,
+ "LABEL_6": 6,
+ "LABEL_7": 7,
+ "LABEL_8": 8,
+ "LABEL_9": 9
+ },
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "position_embedding_type": "absolute",
+ "torch_dtype": "float32",
+ "transformers_version": "4.37.0",
+ "type_vocab_size": 2,
+ "use_cache": true,
+ "vocab_size": 32000
+ }
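Because train.py passes only num_labels (and no explicit id2label) when loading the base model, the exported config.json carries the generic LABEL_0 … LABEL_32 names, which is why predict.py has to re-map them through labels.json. An alternative, shown below as a sketch of an approach this commit does not take, is to bake the readable names in at load time so the pipeline reports them directly; the local ./model and labels.json paths are again assumed.

```python
import json
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

with open("labels.json") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}
label2id = {v: k for k, v in id2label.items()}

tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-cased")
# Passing id2label/label2id here overrides the generic names stored in model/config.json.
model = AutoModelForTokenClassification.from_pretrained(
    "./model", id2label=id2label, label2id=label2id
)

# With real B-/I- label names, the pipeline can also group tokens into spans itself.
ner = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
print(ner("Alsancak Mahallesi, 1475. Sk. No:3, 35220 Konak/İzmir"))
```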
model/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9ad463f26bde0807766033059922cb64c633ca3cedd5637d98607a5a906eaeb
+ size 440231868
model/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "cls_token": "[CLS]",
+ "mask_token": "[MASK]",
+ "pad_token": "[PAD]",
+ "sep_token": "[SEP]",
+ "unk_token": "[UNK]"
+ }
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+ "added_tokens_decoder": {
+ "0": {
+ "content": "[PAD]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "[UNK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "[CLS]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "3": {
+ "content": "[SEP]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "4": {
+ "content": "[MASK]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "clean_up_tokenization_spaces": true,
+ "cls_token": "[CLS]",
+ "do_basic_tokenize": true,
+ "do_lower_case": false,
+ "mask_token": "[MASK]",
+ "max_len": 512,
+ "model_max_length": 512,
+ "never_split": null,
+ "pad_token": "[PAD]",
+ "sep_token": "[SEP]",
+ "strip_accents": null,
+ "tokenize_chinese_chars": true,
+ "tokenizer_class": "BertTokenizer",
+ "unk_token": "[UNK]"
+ }
model/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52ec3649ec97e031b39c9b67ba30182a38c953393e011aa2fefb33da76b0ad9c
+ size 4664
model/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
predict.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ import orjson
+ from transformers import pipeline
+ from transformers import BertTokenizerFast, AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-cased")
+
+ with open("labels.json", "r") as f:
+     id_to_label = {int(k): v for k, v in orjson.loads(f.read()).items()}
+
+ nlp = pipeline(
+     "ner",
+     model="./model",
+     tokenizer=tokenizer,
+     device=0 if torch.cuda.is_available() else -1,  # first GPU when available, otherwise CPU
+ )
+
+ def get_entities(tokens):  # merge per-token B-/I- predictions into labeled spans with averaged scores
+     entities = []
+     entity = None
+     for token in tokens:
+         label_id = int(token["entity"][6:])  # "LABEL_12" -> 12
+         label = id_to_label[label_id]
+         if label.startswith("B-"):
+             if entity:
+                 entity["score"] /= entity["token_count"]
+                 entities.append(entity)
+             entity = {
+                 "label": label[2:],
+                 "ranges": [token["start"], token["end"]],
+                 "score": token["score"],
+                 "token_count": 1,
+             }
+         elif label.startswith("I-"):
+             if entity and entity["label"] == label[2:]:
+                 entity["ranges"][1] = token["end"]
+                 entity["token_count"] += 1
+                 entity["score"] += token["score"]
+             else:
+                 if entity:
+                     entity["ranges"][1] = token["end"]
+                     entity["token_count"] += 1
+                     entity["score"] += token["score"]
+                     entity["score"] /= entity["token_count"]
+                     entities.append(entity)
+                     entity = None
+         else:
+             if entity:
+                 entity["score"] /= entity["token_count"]
+                 entities.append(entity)
+                 entity = None
+     if entity:
+         entity["score"] /= entity["token_count"]
+         entities.append(entity)
+     return entities
+
+ def process(text):
+     nlp_output = nlp(text)
+     entities = get_entities(nlp_output)
+     for entity in entities:
+         print(f"{text[entity['ranges'][0]:entity['ranges'][1]]:<35} {entity['label']:>15} {entity['score'] * 100:.2f}%")
+     print("Average Score: ", sum([token["score"] for token in nlp_output]) / len(nlp_output))
+     print("Labels Found: ", len(entities))
+     print("-" * 70)
+
+ if __name__ == "__main__":
+     examples = [
+         "Osmangazi Mahallesi, Hoca Ahmet Yesevi Cd. No:34, 16050 Osmangazi/Bursa",
+         "Karşıyaka Mahallesi, Mavişehir Caddesi No: 91, Daire 4, 35540 Karşıyaka/İzmir",
+         "Selçuklu Mahallesi, Atatürk Bulvarı No: 55, 42050 Selçuklu/Konya",
+         "Alsancak Mahallesi, 1475. Sk. No:3, 35220 Konak/İzmir",
+         "Kocatepe Mahallesi, Yaşam Caddesi 3. Sokak No:4, 06420 Bayrampaşa/İstanbul",
+     ]
+     for example in examples:
+         print(example)
+         process(example)
+     while True:
+         text = input("Enter text: ")
+         if not text:
+             break
+         process(text)
train.py ADDED
@@ -0,0 +1,199 @@
+ import io
+ import requests
+ import json
+ import time
+ import torch
+ import orjson
+ import zipfile
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments, BertConfig
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+
+ API_URL = "http://dockerbase.duo:8000"  # internal annotation server; not reachable outside our network
+ PROJECT_ID = 1
+
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ def load_data():  # export the approved annotations from the annotation server and parse the JSONL
+
+     res = requests.post(
+         API_URL + "/v1/auth/login/",
+         json={"username": "admin", "password": "123"}
+     )
+     token = res.json()["key"]
+
+     res = requests.post(API_URL + f"/v1/projects/{PROJECT_ID}/download",
+         json={"format": "JSONL", "exportApproved": True},
+         headers={"Authorization": "Token " + token}
+     )
+     task_id = res.json()["task_id"]
+
+
+     ready = False
+     print("Waiting for export task to be ready.", end="")
+     while not ready:
+         res = requests.get(
+             API_URL + "/v1/tasks/status/" + str(task_id),
+             headers={"Authorization": "Token " + token}
+         )
+         ready = res.json()["ready"]
+         if not ready:
+             time.sleep(1)
+             print(".", end="")
+     print("")
+
+     res = requests.get(
+         API_URL + f"/v1/projects/{PROJECT_ID}/download",
+         params={"taskId": task_id},
+         headers={"Authorization": "Token " + token}
+     )
+
+     zip_file = io.BytesIO(res.content)
+     with zipfile.ZipFile(zip_file, "r") as zip_ref:
+         data = zip_ref.read("admin.jsonl").decode("utf-8")
+
+     res = requests.get(
+         API_URL + f"/v1/projects/{PROJECT_ID}/span-types",
+         headers={"Authorization": "Token " + token}
+     )
+
+     labels = res.json()
+
+     return labels, [orjson.loads(line) for line in data.split("\n") if line]
+
+ labels, data = load_data()
+ label_to_id = {}
+ for i, label in enumerate(labels):
+     label_to_id["B-" + label["text"]] = i * 2 + 1  # span type i: B- tag at 2i+1, I- tag at 2i+2
+     label_to_id["I-" + label["text"]] = i * 2 + 2
+ label_to_id["O"] = 0
+ id_to_label = {v: k for k, v in label_to_id.items()}
+
+ tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-turkish-cased")
+ model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-base-turkish-cased", num_labels=len(label_to_id)).to(device)
+
+
+ from datasets import DatasetDict, Dataset
+
+
+ def preprocess_data(item, tokenizer, label_to_id):
+     text = item['text']
+
+     inputs = tokenizer(
+         text,
+         return_offsets_mapping=True,
+         return_tensors="pt",
+         truncation=True,
+         padding='max_length',
+         max_length=128,
+     )
+
+     input_ids = inputs["input_ids"]
+     attention_mask = inputs["attention_mask"]
+     offset_mapping = inputs["offset_mapping"]
+
+     labels = ["O"] * 128  # one BIO tag per word-piece position; positions outside any span stay "O"
+     last_label = "O"
+     for token_idx, [off_start, off_end] in enumerate(offset_mapping[0]):
+         if off_start == off_end:
+             continue
+
+         for start, end, label in item['label']:
+             if start <= off_start and off_end <= end:
+                 if last_label == label:
+                     labels[token_idx] = "I-" + label
+                 else:
+                     labels[token_idx] = "B-" + label
+                 last_label = label
+                 break
+
+     # Convert labels to ids
+     labels = [label_to_id[label] for label in labels]
+
+     return {
+         "input_ids": input_ids.flatten(),
+         "attention_mask": attention_mask.flatten(),
+         "labels": labels,
+     }
+
+
+ class AddressDataset(Dataset):  # unused below; Dataset.from_generator is used instead
+     def __init__(self, dataset):
+         self.dataset = dataset
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, index):
+         item = self.dataset[index]
+         return {key: torch.tensor(val) for key, val in item.items()}
+
+
+
+ dataset = Dataset.from_generator(
+     lambda: (preprocess_data(item, tokenizer, label_to_id) for item in data),
+ )
+
+ dataset = dataset.train_test_split(test_size=0.2)
+ dataset = DatasetDict({
+     "train": dataset["train"],
+     "test": dataset["test"]
+ })
+
+
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=35,
+     per_device_train_batch_size=32,
+     per_device_eval_batch_size=32,
+     # logging_dir="./logs",
+     # logging_first_step=True,
+     # evaluation_strategy="epoch",
+     # save_strategy="epoch",
+     logging_strategy="epoch",
+     # load_best_model_at_end=True,
+ )
+
+ from sklearn.preprocessing import MultiLabelBinarizer
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
+
+ def compute_metrics(pred, id_to_label):
+     labels = pred.label_ids
+     preds = pred.predictions.argmax(-1)
+
+     labels = [[id_to_label[label_id] for label_id in label_ids] for label_ids in labels]
+     preds = [[id_to_label[pred_id] for pred_id in pred_ids] for pred_ids in preds]
+
+     labels = [[label for label in seq if label != "O"] for seq in labels]  # drop the O tag before binarizing
+     preds = [[pred for pred in seq if pred != "O"] for seq in preds]
+
+     mlb = MultiLabelBinarizer()
+     mlb.fit([id_to_label.values()])
+     labels = mlb.transform(labels)
+     preds = mlb.transform(preds)
+
+     return {
+         "accuracy": accuracy_score(labels, preds),
+         "precision": precision_score(labels, preds, average="micro"),
+         "recall": recall_score(labels, preds, average="micro"),
+         "f1": f1_score(labels, preds, average="micro"),
+     }
+
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=dataset["train"],
+     eval_dataset=dataset["test"],
+     tokenizer=tokenizer,
+     compute_metrics=lambda p: compute_metrics(p, id_to_label),
+ )
+
+ trainer.train()
+ trainer.evaluate()
+
+ with open("./labels.json", "w") as f:
+     json.dump(id_to_label, f)
+
+ trainer.save_model("./model")