File size: 4,680 Bytes
21d29cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import ast
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import (
    HfArgumentParser,
)

# Module-level logger; main() attaches a stdout handler and sets its level to INFO.
logger = logging.getLogger(__name__)


@dataclass
class TokenizerArguments:
    """
    Arguments describing the tokenizer to train and the data to train it on.

    At least one of `dataset_name` or `train_file` must be provided. After
    initialisation, `special_tokens` always ends with the twenty placeholder
    tokens `[U1]` ... `[U20]`, appended either to the list supplied by the
    caller or to the built-in default list.

    Raises:
        ValueError: If neither `dataset_name` nor `train_file` is given, or if
            `train_file` does not have a `csv`, `json` or `txt` extension.
    """

    output_dir: str = field(
        default=".",
        metadata={"help": "The output directory where the config will be written."},
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    special_tokens: Optional[List[str]] = field(
        default=None,
        metadata={"help": "The list of special tokens that you want to add in your training."}
    )
    vocab_size: Optional[int] = field(
        default=50257,
        metadata={"help": "The size of the final vocabulary, including all tokens and alphabet"}
    )
    min_frequency: Optional[int] = field(
        default=2,
        metadata={"help": "The minimum frequency a pair should have in order to be merged"}
    )
    show_progress: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to show progress bars while training"}
    )

    def __post_init__(self):
        """Fill in the default special tokens and validate the data source."""
        if self.special_tokens is None:
            self.special_tokens = [
                "<s>", "<pad>", "</s>", "<unk>", "<mask>",
                "<|endoftext|>", "<|startoftext|>",
                "<sep>", "<cls>", "<nl>", "<tab>", "<zwnj>"
            ]

        # `+` builds a new list, so a list passed in by the caller is not mutated.
        self.special_tokens = self.special_tokens + [f"[U{i}]" for i in range(1, 21)]

        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training file.")

        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Raise explicitly instead of using `assert`: assertions are removed
            # when Python runs with -O, which would let a bad file type through.
            if extension not in ("csv", "json", "txt"):
                raise ValueError("`train_file` should be a csv, a json or a txt file.")


def main():
    """
    Train a byte-level BPE tokenizer and write it to `output_dir`.

    Arguments are parsed into a `TokenizerArguments` instance, either from a
    single `.json` file given as the only command-line argument or from regular
    command-line flags. The training data is the "train" split of a dataset
    loaded by name (`dataset_name`) or from a local csv/json/txt file
    (`train_file`); its "text" column is fed to the tokenizer in batches.

    The files produced by `tokenizer.save_model` and a `tokenizer.json` file
    are written into `output_dir`, which is created if it does not exist.
    """
    parser = HfArgumentParser([TokenizerArguments])
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        tokenizer_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
    else:
        tokenizer_args = parser.parse_args_into_dataclasses()[0]

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO)

    logger.info("Training tokenizer")

    if tokenizer_args.dataset_name is not None:
        dataset = load_dataset(
            tokenizer_args.dataset_name,
            tokenizer_args.dataset_config_name,
            cache_dir=tokenizer_args.cache_dir,
            split="train"
        )
    else:
        extension = tokenizer_args.train_file.split(".")[-1]
        loader_kwargs = {}
        if extension == "csv":
            # Only the csv loader takes a delimiter; forwarding it to the text
            # or json loaders is not a valid option for them.
            loader_kwargs["delimiter"] = "\t"
        elif extension == "txt":
            # The datasets loader for plain text files is named "text".
            extension = "text"

        dataset = load_dataset(
            extension,
            data_files={"train": tokenizer_args.train_file},
            cache_dir=tokenizer_args.cache_dir,
            # Request the split explicitly so a single dataset is returned
            # rather than a dict of splits; the batching below needs len()
            # and slicing over rows.
            split="train",
            **loader_kwargs,
        )

    tokenizer = ByteLevelBPETokenizer()

    def batch_iterative(batch_size=1000):
        """Yield the "text" column of `dataset` in lists of up to `batch_size` rows."""
        for start in range(0, len(dataset), batch_size):
            yield dataset[start: start + batch_size]["text"]

    tokenizer.train_from_iterator(
        batch_iterative(),
        vocab_size=tokenizer_args.vocab_size,
        special_tokens=tokenizer_args.special_tokens,
        min_frequency=tokenizer_args.min_frequency,
        show_progress=tokenizer_args.show_progress,
    )

    os.makedirs(tokenizer_args.output_dir, exist_ok=True)
    tokenizer.save_model(tokenizer_args.output_dir)
    tokenizer.save(os.path.join(tokenizer_args.output_dir, "tokenizer.json"), pretty=True)
    # Report success only after the files have actually been written.
    logger.info("Your tokenizer saved here %s/tokenizer", tokenizer_args.output_dir)


# Run the training pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()