fix: make `eos_token`/`pad_token` overridable and add `pickle` support
Browse files- tokenization_arcade100k.py +17 -3
tokenization_arcade100k.py
CHANGED
@@ -124,8 +124,12 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
|
|
124 |
|
125 |
self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
|
126 |
self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
|
127 |
-
|
128 |
-
self.
|
|
|
|
|
|
|
|
|
129 |
# Expose for convenience
|
130 |
self.mergeable_ranks = self.tokenizer._mergeable_ranks
|
131 |
self.special_tokens = self.tokenizer._special_tokens
|
@@ -133,6 +137,16 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
|
|
133 |
def __len__(self):
|
134 |
return self.tokenizer.n_vocab
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
@property
|
137 |
def vocab_size(self):
|
138 |
return self.tokenizer.n_vocab
|
@@ -273,4 +287,4 @@ class Arcade100kTokenizer(PreTrainedTokenizer):
|
|
273 |
token_ids = [token_ids]
|
274 |
if skip_special_tokens:
|
275 |
token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
|
276 |
-
return self.tokenizer.decode(token_ids)
|
|
|
124 |
|
125 |
self.decoder = {i: n for n, i in self.tokenizer._mergeable_ranks.items()}
|
126 |
self.decoder.update({i: n for n, i in self.tokenizer._special_tokens.items()})
|
127 |
+
# Provide default `eos_token` and `pad_token`
|
128 |
+
if self.eos_token is None:
|
129 |
+
self.eos_token = self.decoder[self.tokenizer.eot_token]
|
130 |
+
if self.pad_token is None:
|
131 |
+
self.pad_token = self.decoder[self.tokenizer.pad_token]
|
132 |
+
|
133 |
# Expose for convenience
|
134 |
self.mergeable_ranks = self.tokenizer._mergeable_ranks
|
135 |
self.special_tokens = self.tokenizer._special_tokens
|
|
|
137 |
def __len__(self):
|
138 |
return self.tokenizer.n_vocab
|
139 |
|
140 |
+
def __getstate__(self):
|
141 |
+
# Required for `pickle` support
|
142 |
+
state = self.__dict__.copy()
|
143 |
+
del state["tokenizer"]
|
144 |
+
return state
|
145 |
+
|
146 |
+
def __setstate__(self, state):
|
147 |
+
self.__dict__.update(state)
|
148 |
+
self.tokenizer = tiktoken.Encoding(**self._tiktoken_config)
|
149 |
+
|
150 |
@property
|
151 |
def vocab_size(self):
|
152 |
return self.tokenizer.n_vocab
|
|
|
287 |
token_ids = [token_ids]
|
288 |
if skip_special_tokens:
|
289 |
token_ids = [i for i in token_ids if i < self.tokenizer.eot_token]
|
290 |
+
return self.tokenizer.decode(token_ids)
|