File size: 5,255 Bytes
21d29cb 7cfca48 21d29cb 7cfca48 21d29cb 7cfca48 21d29cb 7cfca48 21d29cb 7cfca48 21d29cb 7cfca48 21d29cb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
HfArgumentParser,
)
from data_utils import (
filter_by_lang_regex,
filter_by_num_tokens,
filter_by_num_sents,
normalizer
)
# Module-level logger for this script; main() raises its level to INFO.
logger = logging.getLogger(__name__)
@dataclass
class TokenizerArguments:
    """Arguments describing the training corpus and tokenizer training settings.

    At least one data source is required: ``dataset_name`` (loaded through the
    ``datasets`` library) or ``train_file`` (a local csv/json/txt file). When
    both are given, ``main()`` uses ``dataset_name``.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    special_tokens: Optional[List[str]] = field(
        default=None,
        metadata={"help": "The list of special tokens that you want to add in your training."}
    )
    vocab_size: Optional[int] = field(
        default=50257,
        metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
    )
    min_frequency: Optional[int] = field(
        default=2,
        metadata={"help": "The minimum frequency a pair should have in order to be merged"}
    )
    show_progress: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to show progress bars while training"}
    )

    def __post_init__(self):
        """Fill in the special-token list and validate the data source.

        Raises:
            ValueError: if neither ``dataset_name`` nor ``train_file`` is set, or
                if ``train_file`` does not end in ``csv``, ``json`` or ``txt``.
        """
        if self.special_tokens is None:
            self.special_tokens = [
                "<s>", "<pad>", "</s>", "<unk>", "<mask>",
                "<|endoftext|>", "<|startoftext|>",
                "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>",
            ]
        # The 20 extra tokens [U1]..[U20] are always appended, also to a
        # user-supplied list (presumably reserved for later use — confirm).
        # Concatenation builds a new list, so the caller's list is not mutated.
        self.special_tokens = self.special_tokens + [f"[U{i}]" for i in range(1, 21)]

        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Explicit raise instead of ``assert``: asserts are stripped under
            # ``python -O``, which would let unsupported files through.
            if extension not in ("csv", "json", "txt"):
                raise ValueError("`train_file` should be a csv, a json or a txt file.")
def main():
    """Train a byte-level BPE tokenizer and save it to ``output_dir``.

    Arguments are parsed into ``TokenizerArguments``. They come either from a
    single ``.json`` file passed as the only CLI argument or from command-line
    flags.

    The corpus comes from ``dataset_name`` when it is set; in that case the
    first 10% of its ``train`` split is used. Otherwise the corpus comes from
    ``train_file``. Rows are filtered with the ``data_utils`` helpers and
    normalized. The tokenizer is trained on the ``text`` column in batches.
    Finally, the ``save_model`` output and a ``tokenizer.json`` are written to
    ``output_dir``.
    """
    parser = HfArgumentParser([TokenizerArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        tokenizer_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        tokenizer_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)
    logger.info("Training tokenizer")

    if tokenizer_args.dataset_name is not None:
        raw_dataset = load_dataset(
            tokenizer_args.dataset_name,
            tokenizer_args.dataset_config_name,
            cache_dir=tokenizer_args.cache_dir,
            split="train[:10%]"
        )
    else:
        data_files = {"train": tokenizer_args.train_file}
        extension = tokenizer_args.train_file.split(".")[-1]
        extra_kwargs = {}
        if extension == "txt":
            extension = "text"
        elif extension == "csv":
            # Only the csv builder has a ``delimiter`` option; the text/json
            # builders reject unknown config keys with a ValueError.
            extra_kwargs["delimiter"] = "\t"
        raw_dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=tokenizer_args.cache_dir,
            # Without ``split`` a DatasetDict is returned; len() and slicing
            # below require a single Dataset.
            split="train",
            **extra_kwargs,
        )

    logger.info("Preprocessing the dataset")
    # Filter semantics and thresholds are defined by data_utils.
    dataset = raw_dataset.filter(lambda example: filter_by_lang_regex(example["text"], ratio=0.75))
    dataset = dataset.filter(lambda example: filter_by_num_tokens(example["text"], gt=64))
    dataset = dataset.filter(lambda example: filter_by_num_sents(example["text"], gt=2))
    dataset = dataset.map(normalizer)
    logger.info(f"Preprocessed dataset kept {len(dataset)} out of {len(raw_dataset)}")

    tokenizer = ByteLevelBPETokenizer()

    def batch_iterative(batch_size=1000):
        # Feed the trainer fixed-size batches of the ``text`` column.
        for i in range(0, len(dataset), batch_size):
            yield dataset[i: i + batch_size]["text"]

    tokenizer.train_from_iterator(
        batch_iterative(),
        vocab_size=tokenizer_args.vocab_size,
        special_tokens=tokenizer_args.special_tokens,
        min_frequency=tokenizer_args.min_frequency,
        show_progress=tokenizer_args.show_progress,
    )

    os.makedirs(tokenizer_args.output_dir, exist_ok=True)
    tokenizer.save_model(tokenizer_args.output_dir)
    tokenizer.save(os.path.join(tokenizer_args.output_dir, "tokenizer.json"), pretty=True)
    # Report the location only once the files have actually been written.
    logger.info(f"Your tokenizer saved here {tokenizer_args.output_dir}")
# Entry point: run the training pipeline only when executed as a script,
# never as a side effect of importing this module.
if __name__ == "__main__":
    main()
|