# Source: plamo-13b / tokenization_plamo.py
# Author: dhigurashi -- commit b0f6037: "support transformers==4.34.0"
# (Web file-viewer residue: raw | history | blame | 5.93 kB)
from __future__ import annotations
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging
# Maps the tokenizer's vocabulary-file argument name to the filename used on disk.
# The file is loaded as (and saved from) a SentencePiece model by PlamoTokenizer.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)
class PlamoTokenizer(PreTrainedTokenizer):  # type: ignore
    """Tokenizer that wraps a SentencePiece model behind the `PreTrainedTokenizer` API.

    The SentencePiece processor is loaded from `vocab_file` and is used for
    tokenizing text, mapping between pieces and ids, and decoding pieces back
    into text. BOS/EOS insertion is controlled by the `add_bos_token` and
    `add_eos_token` flags, both of which default to `False`.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        pad_token: str = "<pad>",
        cls_token: str = "<cls>",
        sep_token: str = "<sep>",
        mask_token: str = "<mask>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        clean_up_tokenization_spaces: bool = False,
        **kwargs: Any,
    ) -> None:
        """Load the SentencePiece model and initialise the base tokenizer.

        Args:
            vocab_file: Path to the SentencePiece model file.
            unk_token: Unknown token string.
            bos_token: Beginning-of-sequence token string.
            eos_token: End-of-sequence token string.
            pad_token: Padding token string.
            cls_token: Classification token string.
            sep_token: Separator token string.
            mask_token: Mask token string.
            sp_model_kwargs: Keyword arguments passed to
                `sentencepiece.SentencePieceProcessor`; `None` means no
                extra arguments.
            clean_up_tokenization_spaces: Forwarded to the base class.
            **kwargs: Forwarded to the base class. `add_bos_token` and
                `add_eos_token` are read from here and default to `False`
                when absent.
        """
        # Fill in the defaults in-place so the same values are stored on the
        # instance and forwarded to the base class through **kwargs.
        kwargs.setdefault("add_bos_token", False)
        kwargs.setdefault("add_eos_token", False)

        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # NOTE(review): the processor is set up before super().__init__(),
        # presumably because the base initializer (transformers>=4.34) already
        # queries the vocabulary -- confirm before reordering.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        self.vocab_file = vocab_file
        self.add_bos_token = kwargs["add_bos_token"]
        self.add_eos_token = kwargs["add_eos_token"]

        super().__init__(
            vocab_file=vocab_file,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            cls_token=cls_token,
            sep_token=sep_token,
            mask_token=mask_token,
            sp_model_kwargs=sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    # the functions below are copied from hf transformers LlamaTokenizer's implementation to fix the behaviour of the tokenizer
    # https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/models/llama/tokenization_llama.py

    def __getstate__(self) -> dict[str, Any]:
        """Return a picklable state dict.

        The live processor is replaced by `None` and its serialized model
        proto is stored under `"sp_model_proto"` so it can be rebuilt in
        `__setstate__`.
        """
        state = self.__dict__.copy()
        state["sp_model"] = None
        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
        return state

    def __setstate__(self, d: dict[str, Any]) -> None:
        """Restore the instance from `d` and rebuild the SentencePiece processor."""
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)

    @property
    def vocab_size(self) -> int:
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self) -> dict[str, int]:
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        # Added tokens are applied last, so they win over same-named pieces.
        vocab.update(self.added_tokens_encoder)
        return vocab

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Converts a sequence of tokens (string) in a single string.

        Runs of ordinary tokens are decoded with the SentencePiece model;
        special tokens are emitted verbatim and never passed to the decoder.

        Args:
            tokens: Token strings to join.

        Returns:
            The decoded text.
        """
        # `all_special_tokens` does not change while decoding: evaluate it
        # once and use a set for constant-time membership tests.
        special_tokens = set(self.all_special_tokens)
        current_sub_tokens: List[str] = []
        parts: List[str] = []
        prev_is_special = False
        for i, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in special_tokens:
                if not prev_is_special and i != 0:
                    parts.append(" ")
                parts.append(self.sp_model.decode(current_sub_tokens))
                parts.append(token)
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        # Flush whatever ordinary tokens follow the last special token.
        parts.append(self.sp_model.decode(current_sub_tokens))
        return "".join(parts)

    def _tokenize(self, text: str) -> List[str]:
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.IdToPiece(index)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Wrap one or two id sequences with BOS/EOS ids as configured.

        Args:
            token_ids_0: Ids of the first sequence.
            token_ids_1: Optional ids of a second sequence; when given it is
                wrapped the same way and appended after the first.

        Returns:
            The combined id list. BOS is only added when `add_bos_token` is
            set and EOS only when `add_eos_token` is set.
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id
        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id
        return output

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                When given, the output filename is prefixed with
                `"<filename_prefix>-"`.

        Returns:
            `Tuple(str)`: Paths to the files saved. If `save_directory` is
            not a directory, an error is logged and `("",)` is returned.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return ("",)

        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            # The original model file still exists elsewhere: copy it.
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # The original file is gone: write the in-memory model instead.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)