|
import os |
|
from typing import List, Union |
|
|
|
import tensorflow as tf |
|
from tensorflow_text import BertTokenizer as BertTokenizerLayer |
|
from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs |
|
|
|
from .tokenization_bert import BertTokenizer |
|
|
|
|
|
class TFBertTokenizer(tf.keras.layers.Layer): |
|
""" |
|
This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the |
|
`from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings |
|
from an existing standard tokenizer object. |
|
|
|
In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run |
|
when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options |
|
than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes |
|
straight from `tf.string` inputs to outputs. |
|
|
|
Args: |
|
vocab_list (`list`): |
|
List containing the vocabulary. |
|
do_lower_case (`bool`, *optional*, defaults to `True`): |
|
Whether or not to lowercase the input when tokenizing. |
|
cls_token_id (`str`, *optional*, defaults to `"[CLS]"`): |
|
The classifier token which is used when doing sequence classification (classification of the whole sequence |
|
instead of per-token classification). It is the first token of the sequence when built with special tokens. |
|
sep_token_id (`str`, *optional*, defaults to `"[SEP]"`): |
|
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for |
|
sequence classification or for a text and a question for question answering. It is also used as the last |
|
token of a sequence built with special tokens. |
|
pad_token_id (`str`, *optional*, defaults to `"[PAD]"`): |
|
The token used for padding, for example when batching sequences of different lengths. |
|
padding (`str`, defaults to `"longest"`): |
|
The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch, |
|
or `"max_length", to pad all inputs to the maximum length supported by the tokenizer. |
|
truncation (`bool`, *optional*, defaults to `True`): |
|
Whether to truncate the sequence to the maximum length. |
|
max_length (`int`, *optional*, defaults to `512`): |
|
The maximum length of the sequence, used for padding (if `padding` is "max_length") and/or truncation (if |
|
`truncation` is `True`). |
|
pad_to_multiple_of (`int`, *optional*, defaults to `None`): |
|
If set, the sequence will be padded to a multiple of this value. |
|
return_token_type_ids (`bool`, *optional*, defaults to `True`): |
|
Whether to return token_type_ids. |
|
return_attention_mask (`bool`, *optional*, defaults to `True`): |
|
Whether to return the attention_mask. |
|
use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`): |
|
If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer |
|
class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to |
|
TFLite. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
vocab_list: List, |
|
do_lower_case: bool, |
|
cls_token_id: int = None, |
|
sep_token_id: int = None, |
|
pad_token_id: int = None, |
|
padding: str = "longest", |
|
truncation: bool = True, |
|
max_length: int = 512, |
|
pad_to_multiple_of: int = None, |
|
return_token_type_ids: bool = True, |
|
return_attention_mask: bool = True, |
|
use_fast_bert_tokenizer: bool = True, |
|
**tokenizer_kwargs, |
|
): |
|
super().__init__() |
|
if use_fast_bert_tokenizer: |
|
self.tf_tokenizer = FastBertTokenizer( |
|
vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs |
|
) |
|
else: |
|
lookup_table = tf.lookup.StaticVocabularyTable( |
|
tf.lookup.KeyValueTensorInitializer( |
|
keys=vocab_list, |
|
key_dtype=tf.string, |
|
values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64), |
|
value_dtype=tf.int64, |
|
), |
|
num_oov_buckets=1, |
|
) |
|
self.tf_tokenizer = BertTokenizerLayer( |
|
lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs |
|
) |
|
|
|
self.vocab_list = vocab_list |
|
self.do_lower_case = do_lower_case |
|
self.cls_token_id = cls_token_id or vocab_list.index("[CLS]") |
|
self.sep_token_id = sep_token_id or vocab_list.index("[SEP]") |
|
self.pad_token_id = pad_token_id or vocab_list.index("[PAD]") |
|
self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1) |
|
self.max_length = max_length |
|
self.padding = padding |
|
self.truncation = truncation |
|
self.pad_to_multiple_of = pad_to_multiple_of |
|
self.return_token_type_ids = return_token_type_ids |
|
self.return_attention_mask = return_attention_mask |
|
|
|
@classmethod |
|
def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs): |
|
""" |
|
Initialize a `TFBertTokenizer` from an existing `Tokenizer`. |
|
|
|
Args: |
|
tokenizer (`PreTrainedTokenizerBase`): |
|
The tokenizer to use to initialize the `TFBertTokenizer`. |
|
|
|
Examples: |
|
|
|
```python |
|
from transformers import AutoTokenizer, TFBertTokenizer |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") |
|
tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer) |
|
``` |
|
""" |
|
do_lower_case = kwargs.pop("do_lower_case", None) |
|
do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case |
|
cls_token_id = kwargs.pop("cls_token_id", None) |
|
cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id |
|
sep_token_id = kwargs.pop("sep_token_id", None) |
|
sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id |
|
pad_token_id = kwargs.pop("pad_token_id", None) |
|
pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id |
|
|
|
vocab = tokenizer.get_vocab() |
|
vocab = sorted(vocab.items(), key=lambda x: x[1]) |
|
vocab_list = [entry[0] for entry in vocab] |
|
return cls( |
|
vocab_list=vocab_list, |
|
do_lower_case=do_lower_case, |
|
cls_token_id=cls_token_id, |
|
sep_token_id=sep_token_id, |
|
pad_token_id=pad_token_id, |
|
**kwargs, |
|
) |
|
|
|
@classmethod |
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs): |
|
""" |
|
Instantiate a `TFBertTokenizer` from a pre-trained tokenizer. |
|
|
|
Args: |
|
pretrained_model_name_or_path (`str` or `os.PathLike`): |
|
The name or path to the pre-trained tokenizer. |
|
|
|
Examples: |
|
|
|
```python |
|
from transformers import TFBertTokenizer |
|
|
|
tf_tokenizer = TFBertTokenizer.from_pretrained("bert-base-uncased") |
|
``` |
|
""" |
|
try: |
|
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) |
|
except: |
|
from .tokenization_bert_fast import BertTokenizerFast |
|
|
|
tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs) |
|
return cls.from_tokenizer(tokenizer, **kwargs) |
|
|
|
def unpaired_tokenize(self, texts): |
|
if self.do_lower_case: |
|
texts = case_fold_utf8(texts) |
|
tokens = self.tf_tokenizer.tokenize(texts) |
|
return tokens.merge_dims(1, -1) |
|
|
|
def call( |
|
self, |
|
text, |
|
text_pair=None, |
|
padding=None, |
|
truncation=None, |
|
max_length=None, |
|
pad_to_multiple_of=None, |
|
return_token_type_ids=None, |
|
return_attention_mask=None, |
|
): |
|
if padding is None: |
|
padding = self.padding |
|
if padding not in ("longest", "max_length"): |
|
raise ValueError("Padding must be either 'longest' or 'max_length'!") |
|
if max_length is not None and text_pair is not None: |
|
|
|
raise ValueError("max_length cannot be overridden at call time when truncating paired texts!") |
|
if max_length is None: |
|
max_length = self.max_length |
|
if truncation is None: |
|
truncation = self.truncation |
|
if pad_to_multiple_of is None: |
|
pad_to_multiple_of = self.pad_to_multiple_of |
|
if return_token_type_ids is None: |
|
return_token_type_ids = self.return_token_type_ids |
|
if return_attention_mask is None: |
|
return_attention_mask = self.return_attention_mask |
|
if not isinstance(text, tf.Tensor): |
|
text = tf.convert_to_tensor(text) |
|
if text_pair is not None and not isinstance(text_pair, tf.Tensor): |
|
text_pair = tf.convert_to_tensor(text_pair) |
|
if text_pair is not None: |
|
if text.shape.rank > 1: |
|
raise ValueError("text argument should not be multidimensional when a text pair is supplied!") |
|
if text_pair.shape.rank > 1: |
|
raise ValueError("text_pair should not be multidimensional!") |
|
if text.shape.rank == 2: |
|
text, text_pair = text[:, 0], text[:, 1] |
|
text = self.unpaired_tokenize(text) |
|
if text_pair is None: |
|
if truncation: |
|
text = text[:, : max_length - 2] |
|
input_ids, token_type_ids = combine_segments( |
|
(text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id |
|
) |
|
else: |
|
text_pair = self.unpaired_tokenize(text_pair) |
|
if truncation: |
|
text, text_pair = self.paired_trimmer.trim([text, text_pair]) |
|
input_ids, token_type_ids = combine_segments( |
|
(text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id |
|
) |
|
if padding == "longest": |
|
pad_length = input_ids.bounding_shape(axis=1) |
|
if pad_to_multiple_of is not None: |
|
|
|
pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of)) |
|
else: |
|
pad_length = max_length |
|
|
|
input_ids, attention_mask = pad_model_inputs(input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id) |
|
output = {"input_ids": input_ids} |
|
if return_attention_mask: |
|
output["attention_mask"] = attention_mask |
|
if return_token_type_ids: |
|
token_type_ids, _ = pad_model_inputs( |
|
token_type_ids, max_seq_length=pad_length, pad_value=self.pad_token_id |
|
) |
|
output["token_type_ids"] = token_type_ids |
|
return output |
|
|
|
def get_config(self): |
|
return { |
|
"vocab_list": self.vocab_list, |
|
"do_lower_case": self.do_lower_case, |
|
"cls_token_id": self.cls_token_id, |
|
"sep_token_id": self.sep_token_id, |
|
"pad_token_id": self.pad_token_id, |
|
} |
|
|