tomato committed
Commit 4db7dc4
1 Parent(s): e6eb31e

fix coding error

Files changed (3)
  1. app.py +3 -9
  2. data_utils.py +0 -319
  3. tokenizers_pegasus.py +0 -599
app.py CHANGED
@@ -1,20 +1,14 @@
 import gradio as gr
 import torch
 from tqdm import tqdm
-from transformers import PegasusForConditionalGeneration
-from tokenizers_pegasus import PegasusTokenizer
 
 MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"
 
-summarizer = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME)
-tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
+from transformers import pipeline
+
+summarizer = pipeline(task="summarization", model=MODEL_NAME)
 
 def summarize(text):
-    inputs = tokenizer(text, max_length=1024, return_tensors="pt")
-
-    # Generate Summary
-    summary_ids = summarizer.generate(inputs["input_ids"])
-    return tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    return summarizer(text)
 
 demo = gr.Blocks(title="⭐ Summ4rizer ⭐")
 demo.encrypt = False
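
Note that the rewritten summarize now returns the raw pipeline output, a list of dicts of the form [{"summary_text": ...}], rather than the decoded string the removed Pegasus code produced. A minimal sketch (not part of the commit) of how a caller could unwrap it:

    # Sketch only: shows the return shape of the summarization pipeline.
    from transformers import pipeline

    MODEL_NAME = "csebuetnlp/mT5_multilingual_XLSum"
    summarizer = pipeline(task="summarization", model=MODEL_NAME)

    def summarize(text: str) -> str:
        # The pipeline returns one dict per input text.
        outputs = summarizer(text)
        return outputs[0]["summary_text"]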
data_utils.py DELETED
@@ -1,319 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import re
-import six
-import unicodedata
-import torch
-import rouge
-import numpy as np
-import random
-# from fengshen.examples.pegasus.pegasus_utils import text_segmentate
-import sys
-
-sys.path.append('../../../')
-
-rouge = rouge.Rouge()
-
-
-is_py2 = six.PY2
-
-if not is_py2:
-    basestring = str
-
-
-def _is_chinese_char(cp):
-    """Checks whether CP is the codepoint of a CJK character."""
-    # This defines a "chinese character" as anything in the CJK Unicode block:
-    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-    #
-    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
-    # despite its name. The modern Korean Hangul alphabet is a different block,
-    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
-    # space-separated words, so they are not treated specially and handled
-    # like all of the other languages.
-    if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
-            or (cp >= 0x20000 and cp <= 0x2A6DF)
-            or (cp >= 0x2A700 and cp <= 0x2B73F)
-            or (cp >= 0x2B740 and cp <= 0x2B81F)
-            or (cp >= 0x2B820 and cp <= 0x2CEAF)
-            or (cp >= 0xF900 and cp <= 0xFAFF)
-            or (cp >= 0x2F800 and cp <= 0x2FA1F)):
-        return True
-
-    return False
-
-
-def _is_whitespace(char):
-    """Checks whether `char` is a whitespace character."""
-    # \t, \n, and \r are technically control characters but we treat them
-    # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
-        return True
-    cat = unicodedata.category(char)
-    if cat == "Zs":
-        return True
-    return False
-
-
-def _is_control(char):
-    """Checks whether `char` is a control character."""
-    # These are technically control characters but we count them as whitespace
-    # characters.
-    if char == "\t" or char == "\n" or char == "\r":
-        return False
-    cat = unicodedata.category(char)
-    if cat.startswith("C"):
-        return True
-    return False
-
-
-def _is_punctuation(char):
-    """Checks whether `char` is a punctuation character."""
-    cp = ord(char)
-    # We treat all non-letter/number ASCII as punctuation.
-    # Characters such as "^", "$", and "`" are not in the Unicode
-    # Punctuation class but we treat them as punctuation anyways, for
-    # consistency.
-    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (
-            cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
-        return True
-    cat = unicodedata.category(char)
-    if cat.startswith("P"):
-        return True
-    return False
-
-
-def is_string(s):
-    """Check whether `s` is a string."""
-    return isinstance(s, basestring)
-
-
-def is_stopwords(word, stopwords):
-    if word in stopwords:
-        return True
-    else:
-        return False
-
-
-def text_segmentate(text):
-    en_seg_pattern = '((?:\\!|\\?|\\.|\\n)+(?:\\s)+)'
-    ch_seg_pattern = '((?:？|！|。|\\n)+)'
-    try:
-        text = re.sub(en_seg_pattern, r'\1[SEP]', text)
-        # print("sub text: ", text)
-    except Exception as e:
-        print("input: ", text)
-        raise e
-    text = re.sub(ch_seg_pattern, r'\1[SEP]', text)
-    # print("sub ch text: ", text)
-    text_list = text.split("[SEP]")
-    text_list = list(filter(lambda x: len(x) != 0, text_list))
-    return text_list
-
-
-def load_stopwords(stopwords_path):
-    stopwords_dict = {}
-    with open(stopwords_path, "r") as rf:
-        for line in rf:
-            line = line.strip()
-            if line not in stopwords_dict:
-                stopwords_dict[line] = 0
-            else:
-                pass
-    return stopwords_dict
-
-
-def text_process(text, max_length):
-    """Split the text into chunks."""
-    texts = text_segmentate(text)
-
-    result, length = [], 0
-    for text in texts:
-        if length + len(text) > max_length * 1.3 and len(result) >= 3:
-            yield result
-            result, length = [], 0
-        result.append(text)
-        length += len(text)
-    if result and len(result) >= 3:
-        yield result
-
-
-def text_process_split_long_content(text, max_length):
-    """Split long text into chunks."""
-    texts = text_segmentate(text)
-
-    result, sentence_num = "", 0
-    for text in texts:
-        if len(text) > 500:
-            if len(result) > 300 and sentence_num >= 3:
-                yield result
-                result, sentence_num = "", 0
-            else:
-                result, sentence_num = "", 0
-            continue
-        else:
-            if len(result) + len(text) > max_length * 1.1 and sentence_num >= 3:
-                yield result
-                result, sentence_num = "", 0
-            result += text
-            sentence_num += 1
-
-    if result and sentence_num >= 3:
-        yield result
-
-
-def gather_join(texts, idxs):
-    """Gather the texts at the given indices and join them."""
-    return ''.join([texts[i] for i in idxs])
-
-
-def gather_join_f1(texts_token, idsx):
-    join_texts = []
-    for id in idsx:
-        join_texts.extend(texts_token[id])
-    return join_texts
-
-
-def compute_rouge(source, target):
-    """Compute rouge-1, rouge-2 and rouge-l."""
-    source, target = ' '.join(source), ' '.join(target)
-    try:
-        scores = rouge.get_scores(hyps=source, refs=target)
-        return {
-            'rouge-1': scores[0]['rouge-1']['f'],
-            'rouge-2': scores[0]['rouge-2']['f'],
-            'rouge-l': scores[0]['rouge-l']['f'],
-        }
-    except ValueError:
-        return {
-            'rouge-1': 0.0,
-            'rouge-2': 0.0,
-            'rouge-l': 0.0,
-        }
-
-
-def remove_stopwords(texts, stopwords_dict):
-    for i, text in enumerate(texts):
-        texts[i] = list(filter(lambda x: x not in stopwords_dict, text))
-    return texts
-
-
-def pseudo_summary_f1(texts,
-                      stopwords,
-                      tokenizer,
-                      max_length,
-                      rouge_strategy="rouge-l"):
-    """Build a pseudo-label summarization dataset."""
-    summary_rate = 0.25
-    max_length = max_length - 1
-    texts_tokens = []
-    sentece_idxs_vec = []
-    for text in texts:
-        if len(texts) == 0:
-            continue
-        try:
-            ids = tokenizer.encode(text.strip())[:-1]
-        except ValueError:
-            print("error, input : ", text)
-            raise ValueError
-        sentece_idxs_vec.append(ids)
-        tokens = [tokenizer._convert_id_to_token(token) for token in ids]
-        texts_tokens.append(tokens)
-
-    texts_tokens_rm = remove_stopwords(texts_tokens, stopwords)
-    source_idxs, target_idxs = list(range(len(texts))), []
-
-    assert len(texts_tokens) == len(texts)
-    # truncate_index = 0
-    while True:
-        sims = []
-        for i in source_idxs:
-            new_source_idxs = [j for j in source_idxs if j != i]
-            new_target_idxs = sorted(target_idxs + [i])
-            new_source = gather_join_f1(texts_tokens_rm, new_source_idxs)
-            new_target = gather_join_f1(texts_tokens_rm, new_target_idxs)
-            sim = compute_rouge(new_source, new_target)[rouge_strategy]
-            sims.append(sim)
-        new_idx = source_idxs[np.argmax(sims)]
-        del sims
-        source_idxs.remove(new_idx)
-        target_idxs = sorted(target_idxs + [new_idx])
-        source = gather_join(texts, source_idxs)
-        target = gather_join(texts, target_idxs)
-        try:
-            if (len(source_idxs) == 1
-                    or 1.0 * len(target) / len(source) > summary_rate):
-                break
-        except ZeroDivisionError as e:
-            print(e)
-            print(texts)
-            print("source: ", source)
-            print("target: ", target)
-
-    if len(source) < len(target):
-        source, target = target, source
-        source_idxs, target_idxs = target_idxs, source_idxs
-
-    return sentece_idxs_vec, source, target, source_idxs, target_idxs
-
-
-def get_input_mask(sentence_id_vec, indexs):
-    target_idxs = []
-    input_idxs = []
-    kMaskSentenceTokenId = 2
-    kEosTokenId = 1
-    mask_sentence_options_cumulative_prob = [0.9, 0.9, 1, 1]
-    for index in indexs:
-        target_idxs.extend(sentence_id_vec[index])
-        choice = random.uniform(0, 1)
-        if choice < mask_sentence_options_cumulative_prob[0]:
-            # print("mask index: ", index)
-            sentence_id_vec[index] = [kMaskSentenceTokenId]
-        elif choice < mask_sentence_options_cumulative_prob[1]:
-            # print("replace index: ", index)
-            replace_id = random.randint(0, len(sentence_id_vec))
-            sentence_id_vec[index] = sentence_id_vec[replace_id]
-        elif choice < mask_sentence_options_cumulative_prob[2]:
-            pass
-        else:
-            sentence_id_vec[index] = []
-
-    target_idxs.append(kEosTokenId)
-    # print(sentence_id_vec)
-    for index, sentence_id in enumerate(sentence_id_vec):
-        # print(index, sentence_id)
-        if len(sentence_id) == 0:
-            continue
-        input_idxs.extend(sentence_id_vec[index])
-
-    input_idxs.append(kEosTokenId)
-    return input_idxs, target_idxs
-
-
-def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
-                       decoder_start_token_id: int):
-    """
-    Shift input ids one token to the right.
-    """
-    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
-    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
-    shifted_input_ids[:, 0] = decoder_start_token_id
-
-    if pad_token_id is None:
-        raise ValueError("self.model.config.pad_token_id has to be defined.")
-    # replace possible -100 values in labels by `pad_token_id`
-    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
-
-    return shifted_input_ids
-
-
-def padding_to_maxlength(ids, max_length, pad_id):
-    cur_len = len(ids)
-    len_diff = max_length - cur_len
-    return ids + [pad_id] * len_diff, [1] * cur_len + [0] * len_diff
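
For reference, a minimal sketch (not part of the commit) of how two of the deleted helpers behave; it assumes the module is imported as it existed before this commit, and the token ids and the pad / decoder-start id of 0 are illustrative assumptions:

    import torch
    # from data_utils import padding_to_maxlength, shift_tokens_right  # pre-deletion import

    ids = [5, 17, 42]                       # token ids for one example
    padded, attention_mask = padding_to_maxlength(ids, max_length=6, pad_id=0)
    # padded         -> [5, 17, 42, 0, 0, 0]
    # attention_mask -> [1, 1, 1, 0, 0, 0]

    labels = torch.tensor([padded])
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=0,
                                           decoder_start_token_id=0)
    # decoder_input_ids -> tensor([[0, 5, 17, 42, 0, 0]])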
tokenizers_pegasus.py DELETED
@@ -1,599 +0,0 @@
-
-from fengshen.examples.pegasus.data_utils import (
-    _is_control,
-    _is_punctuation,
-    _is_whitespace,
-    _is_chinese_char)
-from transformers import PreTrainedTokenizer
-from transformers import logging
-from typing import List, Optional, Tuple, Union
-import collections
-import os
-import unicodedata
-import re
-import jieba
-import sys
-
-sys.path.append("../../../../")
-
-jieba.dt.tmp_dir = os.path.expanduser(
-    "/cognitive_comp/dongxiaoqun/software/jieba/tmp/")
-# jieba.enable_parallel(8)
-jieba.initialize()
-
-logger = logging.get_logger(__name__)
-
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
-
-
-def load_vocab(vocab_file):
-    """Loads a vocabulary file into a dictionary."""
-    vocab = collections.OrderedDict()
-    with open(vocab_file, "r", encoding="utf-8") as reader:
-        tokens = reader.readlines()
-    for index, token in enumerate(tokens):
-        token = token.rstrip("\n")
-        vocab[token] = index
-    return vocab
-
-
-def whitespace_tokenize(text):
-    """Runs basic whitespace cleaning and splitting on a piece of text."""
-    text = text.strip()
-    if not text:
-        return []
-    tokens = text.split()
-    return tokens
-
-
-class PegasusTokenizer(PreTrainedTokenizer):
-    # copy from BertTokenizer
-    r"""
-    Construct a Pegasus tokenizer. Based on WordPiece.
-    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
-    this superclass for more information regarding those methods.
-    Args:
-        vocab_file (`str`):
-            File containing the vocabulary.
-        do_lower_case (`bool`, *optional*, defaults to `True`):
-            Whether or not to lowercase the input when tokenizing.
-        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
-            Whether or not to do basic tokenization before WordPiece.
-        never_split (`Iterable`, *optional*):
-            Collection of tokens which will never be split during tokenization. Only has an effect when
-            `do_basic_tokenize=True`
-        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
-            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
-            token instead.
-        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
-            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
-            sequence classification or for a text and a question for question answering. It is also used as the last
-            token of a sequence built with special tokens.
-        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
-            The token used for padding, for example when batching sequences of different lengths.
-        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
-            The classifier token which is used when doing sequence classification (classification of the whole sequence
-            instead of per-token classification). It is the first token of the sequence when built with special tokens.
-        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
-            The token used for masking values. This is the token used when training this model with masked language
-            modeling. This is the token which the model will try to predict.
-        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
-            Whether or not to tokenize Chinese characters.
-            This should likely be deactivated for Japanese (see this
-            [issue](https://github.com/huggingface/transformers/issues/328)).
-        strip_accents (`bool`, *optional*):
-            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
-            value for `lowercase` (as in the original BERT).
-    """
-
-    vocab_files_names = VOCAB_FILES_NAMES
-    model_input_names = ["input_ids", "attention_mask"]
-
-    # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
-    # pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
-    # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-
-    def __init__(self,
-                 vocab_file,
-                 do_lower_case=True,
-                 do_basic_tokenize=True,
-                 never_split=None,
-                 pad_token="<pad>",
-                 eos_token="</s>",
-                 unk_token="<unk>",
-                 mask_token="<mask_2>",
-                 mask_token_sent="<mask_1>",
-                 additional_special_tokens=None,
-                 sep_token="[SEP]",
-                 cls_token="[CLS]",
-                 tokenize_chinese_chars=True,
-                 strip_accents=None,
-                 offset=100,
-                 pre_tokenizer=lambda x: jieba.cut(x, HMM=False),
-                 **kwargs):
-        self.offset = offset
-
-        if additional_special_tokens is not None:
-            if not isinstance(additional_special_tokens, list):
-                raise TypeError(
-                    f"additional_special_tokens should be of type {type(list)}, \
-                    but is {type(additional_special_tokens)}"
-                )
-
-            additional_special_tokens_extended = (
-                ([mask_token_sent] + additional_special_tokens)
-                if mask_token_sent not in additional_special_tokens
-                and mask_token_sent is not None else additional_special_tokens)
-
-            # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
-            additional_special_tokens_extended += [
-                f"<unk_{i}>" for i in range(
-                    len(additional_special_tokens_extended), self.offset - 1)
-            ]
-
-            if len(set(additional_special_tokens_extended)) != len(
-                    additional_special_tokens_extended):
-                raise ValueError(
-                    f"Please make sure that the provided additional_special_tokens \
-                    do not contain an incorrectly shifted list of <unk_x> tokens. \
-                    Found {additional_special_tokens_extended}."
-                )
-            additional_special_tokens = additional_special_tokens_extended
-        else:
-            additional_special_tokens = [
-                mask_token_sent
-            ] if mask_token_sent is not None else []
-            # additional_special_tokens += [f"<unk_{i}>" for i in range(3, self.offset)]
-
-        # print("additional_special_tokens: ", additional_special_tokens)
-
-        if not os.path.isfile(vocab_file):
-            raise ValueError(
-                f"Can't find a vocabulary file at path '{vocab_file}'. \
-                To load the vocabulary from a Google pretrained "
-                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
-            )
-
-        super().__init__(
-            do_lower_case=do_lower_case,
-            do_basic_tokenize=do_basic_tokenize,
-            never_split=never_split,
-            unk_token=unk_token,
-            sep_token=sep_token,
-            pad_token=pad_token,
-            cls_token=cls_token,
-            mask_token=mask_token,
-            eos_token=eos_token,
-            tokenize_chinese_chars=tokenize_chinese_chars,
-            additional_special_tokens=additional_special_tokens,
-            strip_accents=strip_accents,
-            **kwargs,
-        )
-
-        self.pre_tokenizer = pre_tokenizer
-        self.mask_token_sent = mask_token_sent
-        self.vocab = load_vocab(vocab_file)
-
-        self.vocab[self.eos_token] = self.vocab.pop("[unused1]")
-        # self.vocab[self.eos_token] = self.vocab.pop("[unused2]")
-        self.vocab[self.pad_token] = self.vocab.pop("[PAD]")
-        self.vocab[self.unk_token] = self.vocab.pop("[UNK]")
-
-        if self.mask_token_sent is not None:
-            self.vocab[self.mask_token] = self.vocab.pop("[unused3]")
-            self.vocab[self.mask_token_sent] = self.vocab.pop("[unused2]")
-
-        self.ids_to_tokens = collections.OrderedDict([
-            (ids, tok) for tok, ids in self.vocab.items()
-        ])
-        self.do_basic_tokenize = do_basic_tokenize
-        if do_basic_tokenize:
-            self.basic_tokenizer = BasicTokenizer(
-                do_lower_case=do_lower_case,
-                never_split=never_split,
-                tokenize_chinese_chars=tokenize_chinese_chars,
-                strip_accents=strip_accents,
-            )
-        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
-                                                      unk_token=self.unk_token)
-
-    @property
-    def do_lower_case(self):
-        return self.basic_tokenizer.do_lower_case
-
-    @property
-    def vocab_size(self):
-        return len(self.vocab)
-
-    def get_vocab(self):
-        return dict(self.vocab, **self.added_tokens_encoder)
-
-    def _tokenize(self, text):
-        split_tokens = []
-        # print("pegasus_tokenizer: ", text)
-        for text in self.pre_tokenizer(text):
-            if text in self.vocab:
-                split_tokens.append(text)
-            else:
-                if self.do_basic_tokenize:
-                    for token in self.basic_tokenizer.tokenize(
-                            text, never_split=self.all_special_tokens):
-
-                        # If the token is part of the never_split set
-                        if token in self.basic_tokenizer.never_split:
-                            split_tokens.append(token)
-                        else:
-                            split_tokens += self.wordpiece_tokenizer.tokenize(
-                                token)
-                else:
-                    split_tokens = self.wordpiece_tokenizer.tokenize(text)
-        return split_tokens
-
-    def _convert_token_to_id(self, token):
-        """Converts a token (str) in an id using the vocab."""
-        return self.vocab.get(token, self.vocab.get(self.unk_token))
-
-    def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (str) using the vocab."""
-        return self.ids_to_tokens.get(index, self.unk_token)
-
-    @staticmethod
-    def _cjk_punctuation():
-        return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\
-            \uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\
-            \uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\
-            \u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\
-            \u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'
-
-    def convert_ids_to_tokens(
-            self,
-            ids: Union[int, List[int]],
-            skip_special_tokens: bool = False) -> Union[str, List[str]]:
-        """
-        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
-        added tokens.
-        Args:
-            ids (`int` or `List[int]`):
-                The token id (or token ids) to convert to tokens.
-            skip_special_tokens (`bool`, *optional*, defaults to `False`):
-                Whether or not to remove special tokens in the decoding.
-        Returns:
-            `str` or `List[str]`: The decoded token(s).
-        """
-        if isinstance(ids, int):
-            if ids in self.added_tokens_decoder:
-                return self.added_tokens_decoder[ids]
-            else:
-                return self._convert_id_to_token(ids)
-        tokens = []
-        for index in ids:
-            index = int(index)
-            if skip_special_tokens and index in self.all_special_ids and index != 2:
-                continue
-            if index in self.added_tokens_decoder:
-                tokens.append(self.added_tokens_decoder[index])
-            else:
-                tokens.append(self._convert_id_to_token(index))
-        return tokens
-
-    def convert_tokens_to_string(self, tokens):
-        """Converts a sequence of tokens (string) in a single string."""
-        # for token in
-        # tokens = tokens or self.ids_to_tokens(ids)
-        # tokens = [token for token in tokens if not self._is_special(token)]
-
-        text = ''
-        for i, token in enumerate(tokens):
-            if token[:2] == '##':
-                text += token[2:]
-            elif len(token) == 1 and _is_chinese_char(ord(token)):
-                text += token
-            elif len(token) == 1 and _is_punctuation(token):
-                text += token
-                text += ' '
-            elif i > 0 and _is_chinese_char(ord(text[-1])):
-                text += token
-            elif tokens == "</s>":
-                continue
-            else:
-                text += ' '
-                text += token
-
-        text = re.sub(' +', ' ', text)
-        text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
-        punctuation = re.sub(' +', '', self._cjk_punctuation()).strip() + '+-/={(<['
-        punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
-        punctuation_regex = '(%s) ' % punctuation_regex
-        text = re.sub(punctuation_regex, '\\1', text)
-        text = re.sub(r'(\d\.) (\d)', '\\1\\2', text)
-
-        return text.strip()
-        # out_string = " ".join(tokens).replace(" ##", "").strip()
-
-    def build_inputs_with_special_tokens(
-            self,
-            token_ids_0: List[int],
-            token_ids_1: Optional[List[int]] = None) -> List[int]:
-        """
-        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
-        and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
-        - single sequence: `X </s>`
-        - pair of sequences: `A B </s>` (not intended use)
-        BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
-        separator.
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs to which the special tokens will be added.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-        Returns:
-            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
-        """
-        if token_ids_1 is None:
-            return token_ids_0 + [self.eos_token_id]
-        return token_ids_0 + token_ids_1 + [self.eos_token_id]
-
-    def _special_token_mask(self, seq):
-        all_special_ids = set(
-            self.all_special_ids)  # call it once instead of inside list comp
-        # all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
-
-        return [1 if x in all_special_ids else 0 for x in seq]
-
-    def get_special_tokens_mask(
-            self,
-            token_ids_0: List[int],
-            token_ids_1: Optional[List[int]] = None,
-            already_has_special_tokens: bool = False) -> List[int]:
-        """
-        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
-        special tokens using the tokenizer `prepare_for_model` method.
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
-                Whether or not the token list is already formatted with special tokens for the model.
-        Returns:
-            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
-        """
-
-        if already_has_special_tokens:
-            return self._special_token_mask(token_ids_0)
-        elif token_ids_1 is None:
-            return self._special_token_mask(token_ids_0) + [self.eos_token_id]
-        else:
-            return self._special_token_mask(token_ids_0 +
-                                            token_ids_1) + [self.eos_token_id]
-
-    def num_special_tokens_to_add(self, pair=False):
-        """Just EOS"""
-        return 1
-
-    def save_vocabulary(self,
-                        save_directory: str,
-                        filename_prefix: Optional[str] = None) -> Tuple[str]:
-        index = 0
-        if os.path.isdir(save_directory):
-            vocab_file = os.path.join(
-                save_directory,
-                (filename_prefix + "-" if filename_prefix else "") +
-                VOCAB_FILES_NAMES["vocab_file"])
-        else:
-            vocab_file = (filename_prefix +
-                          "-" if filename_prefix else "") + save_directory
-        with open(vocab_file, "w", encoding="utf-8") as writer:
-            for token, token_index in sorted(self.vocab.items(),
-                                             key=lambda kv: kv[1]):
-                if index != token_index:
-                    logger.warning(
-                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
-                        " Please check that the vocabulary is not corrupted!")
-                    index = token_index
-                writer.write(token + "\n")
-                index += 1
-        return (vocab_file, )
-
-
-class BasicTokenizer(object):
-    """
-    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
-    Args:
-        do_lower_case (`bool`, *optional*, defaults to `True`):
-            Whether or not to lowercase the input when tokenizing.
-        never_split (`Iterable`, *optional*):
-            Collection of tokens which will never be split during tokenization. Only has an effect when
-            `do_basic_tokenize=True`
-        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
-            Whether or not to tokenize Chinese characters.
-            This should likely be deactivated for Japanese (see this
-            [issue](https://github.com/huggingface/transformers/issues/328)).
-        strip_accents: (`bool`, *optional*):
-            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
-            value for `lowercase` (as in the original BERT).
-    """
-
-    def __init__(self,
-                 do_lower_case=True,
-                 never_split=None,
-                 tokenize_chinese_chars=True,
-                 strip_accents=None):
-        if never_split is None:
-            never_split = []
-        self.do_lower_case = do_lower_case
-        self.never_split = set(never_split)
-        self.tokenize_chinese_chars = tokenize_chinese_chars
-        self.strip_accents = strip_accents
-
-    def tokenize(self, text, never_split=None):
-        """
-        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
-        WordPieceTokenizer.
-        Args:
-            never_split (`List[str]`, *optional*)
-                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
-                [`PreTrainedTokenizer.tokenize`]) List of tokens not to split.
-        """
-        # union() returns a new set by concatenating the two sets.
-        never_split = self.never_split.union(
-            set(never_split)) if never_split else self.never_split
-        text = self._clean_text(text)
-
-        # This was added on November 1st, 2018 for the multilingual and Chinese
-        # models. This is also applied to the English models now, but it doesn't
-        # matter since the English models were not trained on any Chinese data
-        # and generally don't have any Chinese data in them (there are Chinese
-        # characters in the vocabulary because Wikipedia does have some Chinese
-        # words in the English Wikipedia.).
-        if self.tokenize_chinese_chars:
-            text = self._tokenize_chinese_chars(text)
-        orig_tokens = whitespace_tokenize(text)
-        split_tokens = []
-        for token in orig_tokens:
-            if token not in never_split:
-                if self.do_lower_case:
-                    token = token.lower()
-                    if self.strip_accents is not False:
-                        token = self._run_strip_accents(token)
-                elif self.strip_accents:
-                    token = self._run_strip_accents(token)
-            split_tokens.extend(self._run_split_on_punc(token, never_split))
-
-        output_tokens = whitespace_tokenize(" ".join(split_tokens))
-        return output_tokens
-
-    def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
-        output = []
-        for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
-                continue
-            output.append(char)
-        return "".join(output)
-
-    def _run_split_on_punc(self, text, never_split=None):
-        """Splits punctuation on a piece of text."""
-        if never_split is not None and text in never_split:
-            return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
-        output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
-
-        return ["".join(x) for x in output]
-
-    def _tokenize_chinese_chars(self, text):
-        """Adds whitespace around any CJK character."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if self._is_chinese_char(cp):
-                output.append(" ")
-                output.append(char)
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-    def _is_chinese_char(self, cp):
-        """Checks whether CP is the codepoint of a CJK character."""
-        # This defines a "chinese character" as anything in the CJK Unicode block:
-        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
-        #
-        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
-        # despite its name. The modern Korean Hangul alphabet is a different block,
-        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
-        # space-separated words, so they are not treated specially and handled
-        # like all of the other languages.
-        if ((cp >= 0x4E00 and cp <= 0x9FFF)
-                or (cp >= 0x3400 and cp <= 0x4DBF)  #
-                or (cp >= 0x20000 and cp <= 0x2A6DF)  #
-                or (cp >= 0x2A700 and cp <= 0x2B73F)  #
-                or (cp >= 0x2B740 and cp <= 0x2B81F)  #
-                or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
-                or (cp >= 0xF900 and cp <= 0xFAFF)
-                or (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
-            return True
-
-        return False
-
-    def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if cp == 0 or cp == 0xFFFD or _is_control(char):
-                continue
-            if _is_whitespace(char):
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-
-class WordpieceTokenizer(object):
-    """Runs WordPiece tokenization."""
-
-    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
-        self.vocab = vocab
-        self.unk_token = unk_token
-        self.max_input_chars_per_word = max_input_chars_per_word
-
-    def tokenize(self, text):
-        """
-        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
-        tokenization using the given vocabulary.
-        For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
-        Args:
-            text: A single token or whitespace separated tokens. This should have
-                already been passed through *BasicTokenizer*.
-        Returns:
-            A list of wordpiece tokens.
-        """
-
-        output_tokens = []
-        for token in whitespace_tokenize(text):
-            chars = list(token)
-            if len(chars) > self.max_input_chars_per_word:
-                output_tokens.append(self.unk_token)
-                continue
-
-            is_bad = False
-            start = 0
-            sub_tokens = []
-            while start < len(chars):
-                end = len(chars)
-                cur_substr = None
-                while start < end:
-                    substr = "".join(chars[start:end])
-                    if start > 0:
-                        substr = "##" + substr
-                    if substr in self.vocab:
-                        cur_substr = substr
-                        break
-                    end -= 1
-                if cur_substr is None:
-                    is_bad = True
-                    break
-                sub_tokens.append(cur_substr)
-                start = end
-
-            if is_bad:
-                output_tokens.append(self.unk_token)
-            else:
-                output_tokens.extend(sub_tokens)
-        return output_tokens
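
For reference, a minimal, self-contained sketch (not part of the commit) of the greedy longest-match-first WordPiece fallback the deleted tokenizer used after jieba pre-tokenization. It assumes the WordpieceTokenizer class above is available; the toy vocabulary is an illustrative assumption, not the real vocab.txt shipped with the model:

    # Toy vocabulary; the real tokenizer loads vocab.txt and runs jieba first,
    # falling back to WordPiece only for words that are not in the vocabulary.
    vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    wp = WordpieceTokenizer(vocab=vocab, unk_token="[UNK]")

    print(wp.tokenize("unaffable"))   # -> ['un', '##aff', '##able']
    print(wp.tokenize("xyz"))         # -> ['[UNK]']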