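"""Patched Llama tokenizer classes for `transformers`.

`FixLlamaTokenizer` forces `legacy=False` and maps any literal SPIECE_UNDERLINE
in the input back to a plain space before tokenizing; `FixLlamaTokenizerConverter`
builds the matching fast tokenizer without prepending a prefix space. Registering
the converter in SLOW_TO_FAST_CONVERTERS keeps slow->fast conversion consistent
with the fix.
"""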
from typing import List, Optional

from transformers import (
    LlamaTokenizer,
    LlamaTokenizerFast,
    SLOW_TO_FAST_CONVERTERS,
)

from tokenizers import decoders, normalizers
from transformers.convert_slow_tokenizer import LlamaConverter

SPIECE_UNDERLINE = "▁"

class FixLlamaTokenizer(LlamaTokenizer):
    """Slow (SentencePiece) Llama tokenizer with the legacy behaviour disabled."""

    def __init__(self, *args, **kwargs):
        # Force the non-legacy tokenization path regardless of what the caller passes.
        kwargs['legacy'] = False
        super().__init__(*args, **kwargs)

    def tokenize(self, text, **kwargs) -> List[str]:
        # Treat any literal SPIECE_UNDERLINE in the input as a plain space
        # before tokenizing.
        return super().tokenize(text.replace(SPIECE_UNDERLINE, " "), **kwargs)

    def _tokenize(self, text, **kwargs):
        # Delegate directly to SentencePiece, bypassing the legacy prefix handling.
        return self.sp_model.encode(text, out_type=str)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""

        current_sub_tokens = []
        out_string = ""
        for _, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
        legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
            token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
        }
        # To avoid mixing byte-level and unicode for byte-level BPE
        # we need to build string separately for added tokens and byte-level tokens
        # cf. https://github.com/huggingface/transformers/issues/1133
        sub_texts = []
        current_sub_text = []
        # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string
        for token in filtered_tokens:
            # `filtered_tokens` contains strings, so compare against the special
            # *tokens*, not their ids.
            if skip_special_tokens and token in self.all_special_tokens:
                continue
            if token in legacy_added_tokens:
                if current_sub_text:
                    string = self.convert_tokens_to_string(current_sub_text)
                    if len(string) > 0:
                        sub_texts.append(string)
                    current_sub_text = []
                sub_texts.append(token)
            else:
                current_sub_text.append(token)
        if current_sub_text:
            sub_texts.append(self.convert_tokens_to_string(current_sub_text))

        if spaces_between_special_tokens:
            text = " ".join(sub_texts)
        else:
            text = "".join(sub_texts)

        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text

class FixLlamaTokenizerFast(LlamaTokenizerFast):
    # Point the fast tokenizer at the fixed slow class so conversions from the
    # slow tokenizer go through the fixed converter registered below.
    slow_tokenizer_class = FixLlamaTokenizer

class FixLlamaTokenizerConverter(LlamaConverter):
    def normalizer(self, proto):
        # Only map spaces to the SentencePiece underline; unlike the stock
        # LlamaConverter, no prefix SPIECE_UNDERLINE is prepended to the input.
        return normalizers.Replace(pattern=' ', content=SPIECE_UNDERLINE)

    def decoder(self, replacement, add_prefix_space):
        # Same as the stock decoder, minus the `Strip` step that would drop a
        # leading space when `add_prefix_space` is set.
        return decoders.Sequence(
            [
                decoders.Replace(SPIECE_UNDERLINE, ' '),
                decoders.ByteFallback(),
                decoders.Fuse(),
            ]
        )

# Register for AutoTokenizer (loaded with trust_remote_code) and route the
# slow->fast conversion through the fixed converter.
FixLlamaTokenizer.register_for_auto_class()
FixLlamaTokenizerFast.register_for_auto_class()
SLOW_TO_FAST_CONVERTERS[FixLlamaTokenizer.__name__] = FixLlamaTokenizerConverter
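
# Usage sketch: a quick slow/fast parity check. The tokenizer model path below
# is hypothetical; point it at a real Llama `tokenizer.model` checkpoint.
if __name__ == "__main__":
    slow = FixLlamaTokenizer("path/to/tokenizer.model")
    # `from_slow=True` rebuilds the fast backend from the slow tokenizer via the
    # converter registered in SLOW_TO_FAST_CONVERTERS above.
    fast = FixLlamaTokenizerFast("path/to/tokenizer.model", from_slow=True)

    text = "Hello world"
    print("slow tokens:", slow.tokenize(text))
    print("fast tokens:", fast.tokenize(text))
    ids = slow.encode(text, add_special_tokens=False)
    print("round-trip:", repr(slow.decode(ids)), repr(fast.decode(ids)))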