File size: 4,131 Bytes
071945c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
# coding=utf-8
# Copyright 2024 LY Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
import torch
from transformers import BatchEncoding, PreTrainedTokenizer, T5Tokenizer
from transformers.tokenization_utils_base import (
PaddingStrategy,
PreTokenizedInput,
TextInput,
TruncationStrategy,
)
class CLYPTokenizer(PreTrainedTokenizer):
    """CLYPTokenizer based on rinna/japanese-roberta-base.

    This tokenizer is registered as a custom tokenizer to manually add a CLS
    token to each text. Vocabulary lookups and (de)tokenization are delegated
    to a wrapped ``T5Tokenizer`` loaded from ``rinna/japanese-roberta-base``.
    """

    def __init__(self, max_length: int, padding: str, truncation: bool, **kwargs):
        """Create the wrapped tokenizer and remember the default call options.

        Args:
            max_length: Default total length per sequence used by ``__call__``.
                When special tokens are added, one slot is reserved for CLS.
            padding: Default ``padding`` value used by ``__call__``.
            truncation: Default ``truncation`` value used by ``__call__``.
            **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
        """
        # The wrapped tokenizer is created before super().__init__(),
        # presumably because the base initializer may already call the
        # delegating vocab methods below — keep this order.
        self.tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
        self.tokenizer.do_lower_case = True
        super().__init__(
            max_length=max_length, padding=padding, truncation=truncation, **kwargs
        )
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation

    @property
    def vocab_size(self):
        """Vocabulary size reported by the wrapped tokenizer."""
        return self.tokenizer.vocab_size

    def get_vocab(self) -> dict[str, int]:
        """Return the wrapped tokenizer's token-to-id mapping."""
        return self.tokenizer.get_vocab()

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> tuple[str, ...]:
        """Delegate vocabulary saving to the wrapped tokenizer."""
        return self.tokenizer.save_vocabulary(
            save_directory, filename_prefix=filename_prefix
        )

    def _tokenize(self, text, **kwargs):
        """Split ``text`` into tokens using the wrapped tokenizer."""
        return self.tokenizer._tokenize(text, **kwargs)

    def _convert_token_to_id(self, token):
        """Map a token string to its id via the wrapped tokenizer."""
        return self.tokenizer._convert_token_to_id(token)

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id back to its token string via the wrapped tokenizer."""
        return self.tokenizer._convert_id_to_token(index)

    def __call__(
        self,
        text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
        add_special_tokens: bool = True,
        padding: bool | str | PaddingStrategy | None = None,
        truncation: bool | str | TruncationStrategy | None = None,
        max_length: Optional[int] = None,
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize ``text`` and optionally prepend the CLS token to each row.

        Args:
            text: A single input or a list of inputs. Anything that is not a
                ``list`` is wrapped into a batch of one.
            add_special_tokens: If True, prepend ``cls_token_id`` to every
                sequence (with a matching attention-mask entry of 1). The
                wrapped tokenizer is then asked for ``max_length - 1`` tokens
                so that the total stays within ``max_length``.
            padding: Overrides the constructor default when not None.
            truncation: Overrides the constructor default when not None.
            max_length: Overrides the constructor default when not None.
            **kwargs: Forwarded to the wrapped tokenizer. ``return_tensors``
                and ``return_attention_mask`` are ignored, because this method
                always builds torch tensors that include an attention mask.

        Returns:
            A ``BatchEncoding`` with ``input_ids``, ``attention_mask`` and
            ``position_ids``, each a ``torch.long`` tensor of shape
            ``(batch, seq_len)``. An empty batch yields shape ``(0, 0)``.

        Raises:
            ValueError: from ``torch.tensor`` if rows have different lengths,
                e.g. padding disabled for texts of differing token counts.
        """
        if max_length is None:
            max_length = self.max_length
        if padding is None:
            padding = self.padding
        if truncation is None:
            truncation = self.truncation

        # The output format is fixed below (plain lists -> torch tensors,
        # attention mask always present). A tensor row would not be
        # list-concatenated with [cls_id] (list + tensor adds element-wise),
        # and a missing "attention_mask" key would raise KeyError.
        kwargs.pop("return_tensors", None)
        kwargs.pop("return_attention_mask", None)

        if add_special_tokens:
            # Reserve one position for the CLS token prepended below.
            max_length = max_length - 1
        if not isinstance(text, list):
            # A single input (e.g. one string) becomes a batch of one.
            text = [text]

        out = self.tokenizer(
            text,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            add_special_tokens=False,
            **kwargs,
        )
        input_ids = out["input_ids"]
        attention_mask = out["attention_mask"]
        if add_special_tokens:
            cls_id = self.tokenizer.cls_token_id
            input_ids = [[cls_id] + ids for ids in input_ids]
            attention_mask = [[1] + mask for mask in attention_mask]

        # Every row gets positions 0..seq_len-1. seq_len is taken from the
        # first row, since rows share one length when the batch is padded.
        batch_size = len(input_ids)
        seq_len = len(input_ids[0]) if input_ids else 0
        position_ids = [list(range(seq_len)) for _ in range(batch_size)]

        def as_tensor(rows):
            # reshape keeps an empty batch 2-D, i.e. shape (0, 0).
            return torch.tensor(rows, dtype=torch.long).reshape(batch_size, seq_len)

        return BatchEncoding(
            data={
                "input_ids": as_tensor(input_ids),
                "attention_mask": as_tensor(attention_mask),
                "position_ids": as_tensor(position_ids),
            }
        )
|