File size: 3,621 Bytes
2bd98e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
# -*- coding: utf-8 -*-
"""
@author:cb
@contact:[email protected]
@time:2023/5/30 14:21
@filename:tokenization.py
@software:PyCharm
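@description: FSMTTokenizer subclass for en->zh translation that keeps Chinese punctuation, prepends a language-prefix token id to encoded inputs, and strips spaces before non-alphanumeric characters when decoding.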
"""
import re
from transformers import FSMTTokenizer as fsmt
class FSMTTokenizer(fsmt):
    """FSMT tokenizer customised for English -> Chinese translation.

    Compared with the ``transformers.FSMTTokenizer`` it subclasses, this class:

    * normalizes text with Moses punctuation normalization only (``moses_pipeline``),
      to avoid the stock tokenizer converting Chinese punctuation to ASCII punctuation;
    * calls the Moses tokenizer with ``escape=False``;
    * picks the Moses language and a language-prefix token from the current mode
      (input/``src=True`` -> 'en', target/``src=False`` -> 'zh') and prepends that
      token id to the first sequence of every encoded input;
    * removes whitespace before non-alphanumeric characters when decoding.
    """

    # Matches a whitespace run that is immediately followed by a character that is neither an
    # ASCII letter/digit nor a plain space (e.g. a CJK character or punctuation mark).
    # Raw string: '\s' in a normal literal is an invalid escape (SyntaxWarning on Python 3.12+).
    space_re = re.compile(r'\s*(?=[^a-zA-Z0-9 ]+)\s*')

    # (Moses language code, vocabulary id of the language-prefix token) for source / target text.
    # NOTE(review): ids are hard-coded for the vocabulary this tokenizer ships with — confirm they
    # match the loaded vocab files.
    SRC_LANG_PREFIX = ('en', 64812)
    TGT_LANG_PREFIX = ('zh', 64870)

    # Default to input mode so methods that read these attributes work even before any
    # mode switch has happened (e.g. `encode` on already-numericalized ids).
    lang_prefix, lang_prefix_id = SRC_LANG_PREFIX

    def moses_tokenize(self, text, lang):
        """Tokenize ``text`` with a per-language cached sacremoses ``MosesTokenizer``.

        ``escape=False`` disables sacremoses' XML escaping of characters such as ``&`` and ``<``.

        Args:
            text: normalized text to split into words.
            lang: Moses language code used to build (and cache) the tokenizer.

        Returns:
            List of word strings (``return_str=False``).
        """
        if lang not in self.cache_moses_tokenizer:
            self.cache_moses_tokenizer[lang] = self.sm.MosesTokenizer(lang=lang)
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        """Use the source-language ('en') Moses language and prefix token."""
        self.lang_prefix, self.lang_prefix_id = self.SRC_LANG_PREFIX

    def _switch_to_target_mode(self):
        """Use the target-language ('zh') Moses language and prefix token."""
        self.lang_prefix, self.lang_prefix_id = self.TGT_LANG_PREFIX

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs by adding the language-prefix token and separator tokens.

        No BOS token is used (as in fairseq). The format is:
        - single sequence: `<lang> X </s>`
        - pair of sequences: `<lang> A </s> B </s>`
        where `<lang>` is `self.lang_prefix_id`, chosen by the current input/target mode.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added. Not mutated.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        prefixed = [self.lang_prefix_id] + token_ids_0
        if token_ids_1 is None:
            return prefixed + sep
        return prefixed + sep + token_ids_1 + sep

    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """Return a special-token mask whose length matches ``build_inputs_with_special_tokens``.

        The inherited mask is assumed to follow the parent's `X </s>` layout (no prefix) —
        TODO confirm against the installed transformers version. The extra language-prefix
        position is therefore prepended here and marked as special (1).
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )
        return [1] + super().get_special_tokens_mask(token_ids_0, token_ids_1)

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Return token type ids whose length matches ``build_inputs_with_special_tokens``.

        The language-prefix position belongs to the first segment, so a leading 0 is
        prepended to the parent's result (assumed to cover `A </s> [B </s>]` only).
        """
        return [0] + super().create_token_type_ids_from_sequences(token_ids_0, token_ids_1)

    def moses_pipeline(self, text, lang):
        """Normalize ``text`` before Moses tokenization.

        Only Moses punctuation normalization is applied. The stock FSMTTokenizer pipeline
        does more and converts Chinese punctuation to ASCII punctuation.
        """
        return self.moses_punct_norm(text, lang)

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """Split ``text`` into BPE sub-word tokens.

        Overridden because the stock FSMTTokenizer converts Chinese punctuation to ASCII
        punctuation.

        Args:
            text: string to tokenize.
            lang: ignored. The Moses language comes from ``self.lang_prefix``, which is set
                by the input/target mode switches. Kept for signature compatibility.
            bypass_tokenizer: if True, skip Moses entirely and split on whitespace only.

        Returns:
            List of BPE tokens.
        """
        if self.do_lower_case:
            text = text.lower()
        if bypass_tokenizer:
            words = text.split()
        else:
            normalized = self.moses_pipeline(text, lang=self.lang_prefix)
            words = self.moses_tokenize(normalized, lang=self.lang_prefix)
        split_tokens = []
        for word in words:
            if word:  # skip empty strings produced by tokenization
                split_tokens.extend(self.bpe(word).split(" "))
        return split_tokens

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """Select source/target mode from the ``src`` keyword, then defer to the parent.

        Args:
            text: the text about to be tokenized.
            is_split_into_words: forwarded unchanged to the parent.
            **kwargs: ``src`` (truthy by default) selects input mode ('en') when truthy and
                target mode ('zh') when falsy. It is consumed here so it is not reported as an
                unrecognized keyword argument; all other kwargs are forwarded.

        Returns:
            Whatever the parent ``prepare_for_tokenization`` returns.

        NOTE(review): because ``src`` defaults to True, a call without it switches back to
        input mode even if target mode had just been selected (e.g. by the ``text_target``
        path of ``__call__``). Pass ``src=False`` explicitly when tokenizing target text.
        """
        if kwargs.pop('src', True):
            self._switch_to_input_mode()
        else:
            self._switch_to_target_mode()
        return super().prepare_for_tokenization(text, is_split_into_words=is_split_into_words, **kwargs)

    def convert_tokens_to_string(self, tokens):
        """Join tokens into text, then remove whitespace before non-alphanumeric characters.

        The original author noted that this stripping may be better handled in business logic.

        Only whitespace *followed by* a character that is not an ASCII letter, digit or plain
        space is removed. For example, "hello 中文" -> "hello中文" and "中 文" -> "中文". Whitespace
        after such a character and before an ASCII letter/digit is kept, so "中文 hello" is
        unchanged.
        """
        text = super().convert_tokens_to_string(tokens)
        return self.space_re.sub('', text)
if __name__ == '__main__':
    # Demo: load tokenizer files from the current directory and tokenize a few samples.
    # `tokenize` expects a single string, so each sample is tokenized separately.
    tokenizer = FSMTTokenizer.from_pretrained(r'./')
    r = [tokenizer.tokenize(sample) for sample in ['hello', 'hi']]
    print(r)