robvanderg commited on
Commit
83bc30a
1 Parent(s): 9bcdbd3

Upload 9 files

Browse files

Upload model contents

all_results.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"perplexity": 3.3834318944245285}
config.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/flan-t5-base",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "classifier_dropout": 0.0,
7
+ "d_ff": 2048,
8
+ "d_kv": 64,
9
+ "d_model": 768,
10
+ "decoder_start_token_id": 0,
11
+ "dense_act_fn": "gelu_new",
12
+ "dropout_rate": 0.1,
13
+ "eos_token_id": 1,
14
+ "feed_forward_proj": "gated-gelu",
15
+ "initializer_factor": 1.0,
16
+ "is_encoder_decoder": true,
17
+ "is_gated_act": true,
18
+ "layer_norm_epsilon": 1e-06,
19
+ "model_type": "t5",
20
+ "n_positions": 512,
21
+ "num_decoder_layers": 12,
22
+ "num_heads": 12,
23
+ "num_layers": 12,
24
+ "output_past": true,
25
+ "pad_token_id": 0,
26
+ "relative_attention_max_distance": 128,
27
+ "relative_attention_num_buckets": 32,
28
+ "task_specific_params": {
29
+ "summarization": {
30
+ "early_stopping": true,
31
+ "length_penalty": 2.0,
32
+ "max_length": 200,
33
+ "min_length": 30,
34
+ "no_repeat_ngram_size": 3,
35
+ "num_beams": 4,
36
+ "prefix": "summarize: "
37
+ },
38
+ "translation_en_to_de": {
39
+ "early_stopping": true,
40
+ "max_length": 300,
41
+ "num_beams": 4,
42
+ "prefix": "translate English to German: "
43
+ },
44
+ "translation_en_to_fr": {
45
+ "early_stopping": true,
46
+ "max_length": 300,
47
+ "num_beams": 4,
48
+ "prefix": "translate English to French: "
49
+ },
50
+ "translation_en_to_ro": {
51
+ "early_stopping": true,
52
+ "max_length": 300,
53
+ "num_beams": 4,
54
+ "prefix": "translate English to Romanian: "
55
+ }
56
+ },
57
+ "tie_word_embeddings": false,
58
+ "torch_dtype": "float32",
59
+ "transformers_version": "4.33.1",
60
+ "use_cache": true,
61
+ "vocab_size": 32128
62
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 0,
6
+ "transformers_version": "4.33.1"
7
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3196a90a28b15607731c72c47b8a7a4f925894dc01a676f6040e1ad9310d0338
3
+ size 990408885
run_t5_mlm_torch.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is a mix of:
2
+ # https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py
3
+ # https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
4
+ # Since there seems to be no way to retrain T5 without Flax
5
+ # Done by Rob van der Goot (09-2023): [email protected]
6
+
7
+ # Biggest TODO is probably dynamic masking; in the current version the data is only
8
+ # prepared once. I also broke the tensorboard functionality (see comments).
9
+
10
+ import json
11
+ import logging
12
+ import math
13
+ import os
14
+ import sys
15
+ import time
16
+ import warnings
17
+ from dataclasses import asdict, dataclass, field
18
+
19
+ from enum import Enum
20
+ from itertools import chain
21
+ from pathlib import Path
22
+ from typing import Dict, List, Optional
23
+
24
+ from datasets import load_dataset
25
+ from huggingface_hub import Repository, create_repo
26
+ from tqdm import tqdm
27
+ from torch.utils.data import DataLoader
28
+
29
+ from accelerate import Accelerator, DistributedType
30
+ from accelerate.logging import get_logger
31
+ from accelerate.utils import set_seed
32
+ from transformers import (
33
+ CONFIG_MAPPING,
34
+ MODEL_FOR_MASKED_LM_MAPPING,
35
+ AutoTokenizer,
36
+ BatchEncoding,
37
+ T5ForConditionalGeneration,
38
+ HfArgumentParser,
39
+ PreTrainedTokenizerBase,
40
+ T5Config,
41
+ is_tensorboard_available,
42
+ set_seed,
43
+ )
44
+ from transformers.utils import send_example_telemetry
45
+ from transformers import AutoModel, get_linear_schedule_with_warmup
46
+ import torch
47
+ torch.manual_seed(8446)
48
+
49
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
50
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
51
+
52
+ # Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
53
+ def shift_tokens_right(input_ids , pad_token_id: int, decoder_start_token_id: int) :
54
+ """
55
+ Shift input ids one token to the right.
56
+ """
57
+ shifted_input_ids = torch.zeros(input_ids.shape, dtype=input_ids.dtype)
58
+ #input_ids = torch.tensor(input_ids)
59
+ shifted_input_ids[:,1:] = input_ids[:,:-1]
60
+ shifted_input_ids[:,0] = decoder_start_token_id
61
+ #shifted_input_ids = jnp.zeros_like(input_ids)
62
+ #shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
63
+ #shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
64
+ shifted_input_ids[shifted_input_ids==-100] = pad_token_id
65
+ #shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
66
+ return shifted_input_ids
67
+
68
+ @dataclass
69
+ class TrainingArguments:
70
+ output_dir: str = field(
71
+ metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
72
+ )
73
+ overwrite_output_dir: bool = field(
74
+ default=False,
75
+ metadata={
76
+ "help": (
77
+ "Overwrite the content of the output directory. "
78
+ "Use this to continue training if output_dir points to a checkpoint directory."
79
+ )
80
+ },
81
+ )
82
+ do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
83
+ do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
84
+ per_device_train_batch_size: int = field(
85
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
86
+ )
87
+ per_device_eval_batch_size: int = field(
88
+ default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
89
+ )
90
+ learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
91
+ weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
92
+ adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
93
+ adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
94
+ adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
95
+ adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
96
+ num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
97
+ warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
98
+ logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
99
+ save_steps: str = field(default=None, metadata={"help": "Save checkpoint every X updates steps."})
100
+ eval_steps: int = field(default=100, metadata={"help": "Run an evaluation every X steps."})
101
+ seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
102
+ push_to_hub: bool = field(
103
+ default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
104
+ )
105
+ hub_model_id: str = field(
106
+ default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
107
+ )
108
+ hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
109
+
110
+ def __post_init__(self):
111
+ if self.output_dir is not None:
112
+ self.output_dir = os.path.expanduser(self.output_dir)
113
+
114
+ def to_dict(self):
115
+ """
116
+ Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
117
+ the token values by removing their value.
118
+ """
119
+ d = asdict(self)
120
+ for k, v in d.items():
121
+ if isinstance(v, Enum):
122
+ d[k] = v.value
123
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
124
+ d[k] = [x.value for x in v]
125
+ if k.endswith("_token"):
126
+ d[k] = f"<{k.upper()}>"
127
+ return d
128
+
129
+ @dataclass
130
+ class ModelArguments:
131
+ """
132
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
133
+ """
134
+
135
+ model_name_or_path: Optional[str] = field(
136
+ default=None,
137
+ metadata={
138
+ "help": (
139
+ "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
140
+ )
141
+ },
142
+ )
143
+ model_type: Optional[str] = field(
144
+ default=None,
145
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
146
+ )
147
+ config_name: Optional[str] = field(
148
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
149
+ )
150
+ tokenizer_name: Optional[str] = field(
151
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
152
+ )
153
+ cache_dir: Optional[str] = field(
154
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
155
+ )
156
+ use_fast_tokenizer: bool = field(
157
+ default=True,
158
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
159
+ )
160
+ dtype: Optional[str] = field(
161
+ default="float32",
162
+ metadata={
163
+ "help": (
164
+ "Floating-point format in which the model weights should be initialized and trained. Choose one of"
165
+ " `[float32, float16, bfloat16]`."
166
+ )
167
+ },
168
+ )
169
+ token: str = field(
170
+ default=None,
171
+ metadata={
172
+ "help": (
173
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
174
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
175
+ )
176
+ },
177
+ )
178
+ use_auth_token: bool = field(
179
+ default=None,
180
+ metadata={
181
+ "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
182
+ },
183
+ )
184
+
185
+ @dataclass
186
+ class DataTrainingArguments:
187
+ """
188
+ Arguments pertaining to what data we are going to input our model for training and eval.
189
+ """
190
+
191
+ dataset_name: Optional[str] = field(
192
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
193
+ )
194
+ dataset_config_name: Optional[str] = field(
195
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
196
+ )
197
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
198
+ validation_file: Optional[str] = field(
199
+ default=None,
200
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
201
+ )
202
+ train_ref_file: Optional[str] = field(
203
+ default=None,
204
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
205
+ )
206
+ validation_ref_file: Optional[str] = field(
207
+ default=None,
208
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
209
+ )
210
+ overwrite_cache: bool = field(
211
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
212
+ )
213
+ validation_split_percentage: Optional[int] = field(
214
+ default=5,
215
+ metadata={
216
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
217
+ },
218
+ )
219
+ max_seq_length: Optional[int] = field(
220
+ default=None,
221
+ metadata={
222
+ "help": (
223
+ "The maximum total input sequence length after tokenization and masking. Sequences longer than this"
224
+ " will be truncated. Default to the max input length of the model."
225
+ )
226
+ },
227
+ )
228
+ preprocessing_num_workers: Optional[int] = field(
229
+ default=None,
230
+ metadata={"help": "The number of processes to use for the preprocessing."},
231
+ )
232
+ mlm_probability: float = field(
233
+ default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
234
+ )
235
+ mean_noise_span_length: float = field(
236
+ default=3.0,
237
+ metadata={"help": "Mean span length of masked tokens"},
238
+ )
239
+
240
+ def __post_init__(self):
241
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
242
+ raise ValueError("Need either a dataset name or a training/validation file.")
243
+ else:
244
+ if self.train_file is not None:
245
+ extension = self.train_file.split(".")[-1]
246
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
247
+ if self.validation_file is not None:
248
+ extension = self.validation_file.split(".")[-1]
249
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
250
+
251
+
252
+ def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
253
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
254
+
255
+ Training parameters to avoid padding with random_spans_noise_mask.
256
+ When training a model with random_spans_noise_mask, we would like to set the other
257
+ training hyperparmeters in a way that avoids padding.
258
+ This function helps us compute these hyperparameters.
259
+ We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
260
+ and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
261
+ This function tells us the required number of tokens in the raw example (for split_tokens())
262
+ as well as the length of the encoded targets. Note that this function assumes
263
+ the inputs and targets will have EOS appended and includes that in the reported length.
264
+
265
+ Args:
266
+ inputs_length: an integer - desired length of the tokenized inputs sequence
267
+ noise_density: a float
268
+ mean_noise_span_length: a float
269
+ Returns:
270
+ tokens_length: length of original text in tokens
271
+ targets_length: an integer - length in tokens of encoded targets sequence
272
+ """
273
+
274
+ def _tokens_length_to_inputs_length_targets_length(tokens_length):
275
+ num_noise_tokens = int(round(tokens_length * noise_density))
276
+ num_nonnoise_tokens = tokens_length - num_noise_tokens
277
+ num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
278
+ # inputs contain all nonnoise tokens, sentinels for all noise spans
279
+ # and one EOS token.
280
+ _input_length = num_nonnoise_tokens + num_noise_spans + 1
281
+ _output_length = num_noise_tokens + num_noise_spans + 1
282
+ return _input_length, _output_length
283
+
284
+ tokens_length = inputs_length
285
+
286
+ while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
287
+ tokens_length += 1
288
+
289
+ inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
290
+
291
+ # minor hack to get the targets length to be equal to inputs length
292
+ # which is more likely to have been set to a nice round number.
293
+ if noise_density == 0.5 and targets_length > inputs_length:
294
+ tokens_length -= 1
295
+ targets_length -= 1
296
+ return tokens_length, targets_length
297
+
298
+
299
+ class DataCollatorForT5MLM:
300
+ """
301
+ Data collator used for T5 span-masked language modeling.
302
+ It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
303
+ For more information on how T5 span-masked language modeling works, one can take a look
304
+ at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
305
+ or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .
306
+
307
+ Args:
308
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
309
+ The tokenizer used for encoding the data.
310
+ noise_density (:obj:`float`):
311
+ The probability with which to (randomly) mask tokens in the input.
312
+ mean_noise_span_length (:obj:`float`):
313
+ The average span length of the masked tokens.
314
+ input_length (:obj:`int`):
315
+ The expected input length after masking.
316
+ target_length (:obj:`int`):
317
+ The expected target length after masking.
318
+ pad_token_id: (:obj:`int`):
319
+ The pad token id of the model
320
+ decoder_start_token_id: (:obj:`int):
321
+ The decoder start token id of the model
322
+ """
323
+ def __init__(self,
324
+ tokenizer: PreTrainedTokenizerBase,
325
+ noise_density: float,
326
+ mean_noise_span_length: float,
327
+ input_length: int,
328
+ target_length: int,
329
+ pad_token_id: int,
330
+ decoder_start_token_id: int):
331
+
332
+ self.tokenizer = tokenizer
333
+ self.noise_density = noise_density
334
+ self.mean_noise_span_length = mean_noise_span_length
335
+ self.input_length = input_length
336
+ self.target_length = target_length
337
+ self.pad_token_id = pad_token_id
338
+ self.decoder_start_token_id = decoder_start_token_id
339
+
340
+
341
+ def __call__(self, examples: List[Dict[str, list]]) -> BatchEncoding:
342
+ # convert list to dict and tensorize input
343
+ input_ids = [examples[i]['input_ids'] for i in range(len(examples))]
344
+ max_len = max([len(x) for x in input_ids])
345
+ # could definitely be done neater
346
+ for rowIdx in range(len(input_ids)):
347
+ while len(input_ids[rowIdx]) != max_len:
348
+ input_ids[rowIdx].append(self.pad_token_id)
349
+ batch1 = {'input_ids': input_ids}
350
+ batch1['input_ids'] = torch.tensor(batch1['input_ids'])
351
+ batch = BatchEncoding(batch1)
352
+ #{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
353
+ #)
354
+
355
+ input_ids = batch["input_ids"]
356
+ batch_size, expandend_input_length = input_ids.shape
357
+
358
+ mask_indices = torch.stack([self.random_spans_noise_mask(expandend_input_length) for i in range(batch_size)])
359
+ labels_mask = ~mask_indices
360
+
361
+ input_ids_sentinel = self.create_sentinel_ids(mask_indices)
362
+ labels_sentinel = self.create_sentinel_ids(labels_mask)
363
+
364
+ batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
365
+ batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)
366
+
367
+ self.input_length
368
+ if batch["input_ids"].shape[-1] != self.input_length:
369
+ raise ValueError(
370
+ f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
371
+ f" should be {self.input_length}."
372
+ )
373
+
374
+ if batch["labels"].shape[-1] != self.target_length:
375
+ raise ValueError(
376
+ f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
377
+ f" {self.target_length}."
378
+ )
379
+
380
+ # to check that tokens are correctly preprocessed, one can run `self.tokenizer.batch_decode(input_ids)` and `self.tokenizer.batch_decode(labels)` here...
381
+ batch["decoder_input_ids"] = shift_tokens_right(
382
+ batch["labels"], self.pad_token_id, self.decoder_start_token_id
383
+ )
384
+ return batch
385
+
386
+ def create_sentinel_ids(self, mask_indices):
387
+ """
388
+ Sentinel ids creation given the indices that should be masked.
389
+ The start indices of each mask are replaced by the sentinel ids in increasing
390
+ order. Consecutive mask indices to be deleted are replaced with `-1`.
391
+ """
392
+ mask_indices = mask_indices.type(torch.int8)
393
+ start_indices = mask_indices - torch.roll(mask_indices, 1, dims=-1) * mask_indices
394
+ start_indices[:, 0] = mask_indices[:, 0]
395
+
396
+ #sentinel_ids = start_indices
397
+ sentinel_ids = torch.where(start_indices!=0, torch.cumsum(start_indices, dim=-1), start_indices)
398
+ #sentinel_ids[start_indices != 0] = torch.cumsum(start_indices, dim=-1)#, start_indices)
399
+ sentinel_ids = torch.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
400
+
401
+ sentinel_ids -= mask_indices - start_indices
402
+
403
+ #sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
404
+ #sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
405
+ return sentinel_ids
406
+
407
+ def filter_input_ids(self, input_ids, sentinel_ids):
408
+ """
409
+ Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
410
+ This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
411
+ """
412
+ batch_size = input_ids.shape[0]
413
+
414
+ input_ids_full = torch.where(sentinel_ids != 0, sentinel_ids, input_ids)
415
+ # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
416
+ # masked tokens coming after sentinel tokens and should be removed
417
+ input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
418
+ input_ids = torch.concat(
419
+ [input_ids, torch.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=torch.int32)], dim=-1
420
+ )
421
+ return input_ids
422
+
423
+ def random_spans_noise_mask(self, length):
424
+ """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
425
+
426
+ Noise mask consisting of random spans of noise tokens.
427
+ The number of noise tokens and the number of noise spans and non-noise spans
428
+ are determined deterministically as follows:
429
+ num_noise_tokens = round(length * noise_density)
430
+ num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
431
+ Spans alternate between non-noise and noise, beginning with non-noise.
432
+ Subject to the above restrictions, all masks are equally likely.
433
+
434
+ Args:
435
+ length: an int32 scalar (length of the incoming token sequence)
436
+ noise_density: a float - approximate density of output mask
437
+ mean_noise_span_length: a number
438
+
439
+ Returns:
440
+ a boolean tensor with shape [length]
441
+ """
442
+
443
+ orig_length = length
444
+
445
+
446
+ num_noise_tokens = round(length * self.noise_density)
447
+ num_nonnoise_tokens = length - num_noise_tokens
448
+ # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
449
+ num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
450
+ # num_noise_tokens should be less than num_noise_tokens and num_nonnoise_tokens
451
+ num_noise_spans = round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length)
452
+
453
+ # avoid degeneracy by ensuring positive number of noise spans
454
+ num_noise_spans = max(num_noise_spans, 1)
455
+
456
+ # pick the lengths of the noise spans and the non-noise spans
457
+ def _random_segmentation(num_items, num_segments):
458
+ """Partition a sequence of items randomly into non-empty segments.
459
+ Args:
460
+ num_items: an integer scalar > 0
461
+ num_segments: an integer scalar in [1, num_items]
462
+ Returns:
463
+ a Tensor with shape [num_segments] containing positive integers that add
464
+ up to num_items
465
+ """
466
+ mask_indices = torch.arange(num_items - 1) < (num_segments - 1)
467
+ # https://discuss.pytorch.org/t/shuffling-a-tensor/25422/3
468
+ #np.random.shuffle(mask_indices)
469
+ idx = torch.randperm(mask_indices.nelement())
470
+ mask_indices = mask_indices.view(-1)[idx].view(mask_indices.size())
471
+
472
+ first_in_segment = torch.cat([torch.tensor([False]), mask_indices])
473
+ segment_id = torch.cumsum(first_in_segment, dim=0)
474
+ # count length of sub segments assuming that list is sorted
475
+ _, segment_length = torch.unique(segment_id, return_counts=True)
476
+ return segment_length
477
+
478
+ noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
479
+ nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
480
+
481
+ interleaved_span_lengths = torch.reshape(
482
+ torch.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
483
+ )
484
+ span_starts = torch.cumsum(interleaved_span_lengths, dim=0)[:-1]
485
+ span_start_indicator = torch.zeros((length,), dtype=torch.int8)
486
+ span_start_indicator[span_starts] = True
487
+ span_num = torch.cumsum(span_start_indicator, dim=0)
488
+ is_noise = span_num % 2 == 1
489
+
490
+ return is_noise[:orig_length]
491
+
492
+
493
+ def generate_batch_splits(samples_idx: list, batch_size: int, drop_last=True) -> list:
494
+ """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
495
+ the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
496
+ num_samples = len(samples_idx)
497
+ if drop_last:
498
+ samples_to_remove = num_samples % batch_size
499
+ if samples_to_remove != 0:
500
+ samples_idx = samples_idx[:-samples_to_remove]
501
+ sections_split = num_samples // batch_size
502
+ samples_idx = samples_idx.reshape((sections_split, batch_size))
503
+ else:
504
+ sections_split = math.ceil(num_samples / batch_size)
505
+ samples_idx = torch.split(samples_idx, sections_split)
506
+ return samples_idx
507
+
508
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
509
+ summary_writer.scalar("train_time", train_time, step)
510
+
511
+ train_metrics = get_metrics(train_metrics)
512
+ for key, vals in train_metrics.items():
513
+ tag = f"train_{key}"
514
+ for i, val in enumerate(vals):
515
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
516
+
517
+
518
+ def write_eval_metric(summary_writer, eval_metrics, step):
519
+ for metric_name, value in eval_metrics.items():
520
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
521
+
522
+
523
+ def main():
524
+ # See all possible arguments in src/transformers/training_args.py
525
+ # or by passing the --help flag to this script.
526
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
527
+
528
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
529
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
530
+ # If we pass only one argument to the script and it's the path to a json file,
531
+ # let's parse it to get our arguments.
532
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
533
+ else:
534
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
535
+
536
+ accelerator = Accelerator()
537
+
538
+ if model_args.use_auth_token is not None:
539
+ warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
540
+ if model_args.token is not None:
541
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
542
+ model_args.token = model_args.use_auth_token
543
+
544
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
545
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
546
+ #send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax")
547
+
548
+ if (
549
+ os.path.exists(training_args.output_dir)
550
+ and os.listdir(training_args.output_dir)
551
+ and training_args.do_train
552
+ and not training_args.overwrite_output_dir
553
+ ):
554
+ raise ValueError(
555
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
556
+ "Use --overwrite_output_dir to overcome."
557
+ )
558
+
559
+ # Setup logging
560
+ logging.basicConfig(
561
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
562
+ level=logging.INFO,
563
+ datefmt="[%X]",
564
+ )
565
+
566
+ # Log on each process the small summary:
567
+ logger = logging.getLogger(__name__)
568
+
569
+ # Set the verbosity to info of the Transformers logger (on main process only):
570
+ logger.info(f"Training/evaluation parameters {training_args}")
571
+
572
+ # Set seed before initializing model.
573
+ set_seed(training_args.seed)
574
+
575
+ # Handle the repository creation
576
+ if training_args.push_to_hub:
577
+ # Retrieve of infer repo_name
578
+ repo_name = training_args.hub_model_id
579
+ if repo_name is None:
580
+ repo_name = Path(training_args.output_dir).absolute().name
581
+ # Create repo and retrieve repo_id
582
+ repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
583
+ # Clone repo locally
584
+ repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
585
+
586
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
587
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
588
+ # (the dataset will be downloaded automatically from the datasets Hub).
589
+ #
590
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
591
+ # 'text' is found. You can easily tweak this behavior (see below).
592
+ if data_args.dataset_name is not None:
593
+ # Downloading and loading a dataset from the hub.
594
+ datasets = load_dataset(
595
+ data_args.dataset_name,
596
+ data_args.dataset_config_name,
597
+ cache_dir=model_args.cache_dir,
598
+ token=model_args.token,
599
+ )
600
+
601
+ if "validation" not in datasets.keys():
602
+ datasets["validation"] = load_dataset(
603
+ data_args.dataset_name,
604
+ data_args.dataset_config_name,
605
+ split=f"train[:{data_args.validation_split_percentage}%]",
606
+ cache_dir=model_args.cache_dir,
607
+ token=model_args.token,
608
+ )
609
+ datasets["train"] = load_dataset(
610
+ data_args.dataset_name,
611
+ data_args.dataset_config_name,
612
+ split=f"train[{data_args.validation_split_percentage}%:]",
613
+ cache_dir=model_args.cache_dir,
614
+ token=model_args.token,
615
+ )
616
+ else:
617
+ data_files = {}
618
+ if data_args.train_file is not None:
619
+ data_files["train"] = data_args.train_file
620
+ if data_args.validation_file is not None:
621
+ data_files["validation"] = data_args.validation_file
622
+ extension = data_args.train_file.split(".")[-1]
623
+ if extension == "txt":
624
+ extension = "text"
625
+
626
+ datasets = load_dataset(
627
+ extension,
628
+ data_files=data_files,
629
+ cache_dir=model_args.cache_dir,
630
+ token=model_args.token,
631
+ )
632
+
633
+ if "validation" not in datasets.keys():
634
+ datasets["validation"] = load_dataset(
635
+ extension,
636
+ data_files=data_files,
637
+ split=f"train[:{data_args.validation_split_percentage}%]",
638
+ cache_dir=model_args.cache_dir,
639
+ token=model_args.token,
640
+ )
641
+ datasets["train"] = load_dataset(
642
+ extension,
643
+ data_files=data_files,
644
+ split=f"train[{data_args.validation_split_percentage}%:]",
645
+ cache_dir=model_args.cache_dir,
646
+ token=model_args.token,
647
+ )
648
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
649
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
650
+
651
+ # Load pretrained model and tokenizer
652
+
653
+ if model_args.tokenizer_name:
654
+ tokenizer = AutoTokenizer.from_pretrained(
655
+ model_args.tokenizer_name,
656
+ cache_dir=model_args.cache_dir,
657
+ use_fast=model_args.use_fast_tokenizer,
658
+ token=model_args.token,
659
+ )
660
+ elif model_args.model_name_or_path:
661
+ tokenizer = AutoTokenizer.from_pretrained(
662
+ model_args.model_name_or_path,
663
+ cache_dir=model_args.cache_dir,
664
+ use_fast=model_args.use_fast_tokenizer,
665
+ token=model_args.token,
666
+ )
667
+ else:
668
+ raise ValueError(
669
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
670
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
671
+ )
672
+
673
+ if model_args.config_name:
674
+ config = T5Config.from_pretrained(
675
+ model_args.config_name,
676
+ cache_dir=model_args.cache_dir,
677
+ vocab_size=len(tokenizer),
678
+ token=model_args.token,
679
+ )
680
+ elif model_args.model_name_or_path:
681
+ config = T5Config.from_pretrained(
682
+ model_args.model_name_or_path,
683
+ cache_dir=model_args.cache_dir,
684
+ token=model_args.token,
685
+ )
686
+ else:
687
+ config = CONFIG_MAPPING[model_args.model_type]()
688
+ logger.warning("You are instantiating a new config instance from scratch.")
689
+
690
+ # Preprocessing the datasets.
691
+ # First we tokenize all the texts.
692
+ if training_args.do_train:
693
+ column_names = datasets["train"].column_names
694
+ else:
695
+ column_names = datasets["validation"].column_names
696
+ text_column_name = "text" if "text" in column_names else column_names[0]
697
+
698
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
699
+
700
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
701
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
702
+ def tokenize_function(examples):
703
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
704
+
705
+ tokenized_datasets = datasets.map(
706
+ tokenize_function,
707
+ batched=True,
708
+ num_proc=data_args.preprocessing_num_workers,
709
+ remove_columns=column_names,
710
+ load_from_cache_file=not data_args.overwrite_cache,
711
+ )
712
+
713
+ # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
714
+ # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
715
+ # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
716
+ expanded_inputs_length, targets_length = compute_input_and_target_lengths(
717
+ inputs_length=max_seq_length,
718
+ noise_density=data_args.mlm_probability,
719
+ mean_noise_span_length=data_args.mean_noise_span_length,
720
+ )
721
+
722
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
723
+ def group_texts(examples):
724
+ # Concatenate all texts.
725
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
726
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
727
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
728
+ # customize this part to your needs.
729
+ if total_length >= expanded_inputs_length:
730
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
731
+ # Split by chunks of max_len.
732
+ result = {
733
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
734
+ for k, t in concatenated_examples.items()
735
+ }
736
+ return result
737
+
738
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
739
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
740
+ # might be slower to preprocess.
741
+ #
742
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
743
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
744
+ tokenized_datasets = tokenized_datasets.map(
745
+ group_texts,
746
+ batched=True,
747
+ num_proc=data_args.preprocessing_num_workers,
748
+ load_from_cache_file=not data_args.overwrite_cache,
749
+ )
750
+
751
+ # Enable tensorboard only on the master node
752
+ has_tensorboard = is_tensorboard_available()
753
+ #if has_tensorboard and jax.process_index() == 0:
754
+ # try:
755
+ # from flax.metrics.tensorboard import SummaryWriter
756
+ #
757
+ # summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
758
+ # except ImportError as ie:
759
+ # has_tensorboard = False
760
+ # logger.warning(
761
+ # f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
762
+ # )
763
+ # else:
764
+ # logger.warning(
765
+ # "Unable to display metrics through TensorBoard because the package is not installed: "
766
+ # "Please run pip install tensorboard to enable."
767
+ # )
768
+
769
+ if model_args.model_name_or_path:
770
+ model = T5ForConditionalGeneration.from_pretrained(
771
+ model_args.model_name_or_path,
772
+ config=config,
773
+ #seed=training_args.seed,
774
+ token=model_args.token,
775
+ )
776
+ else:
777
+ config.vocab_size = len(tokenizer)
778
+ model = T5ForConditionalGeneration(
779
+ config,
780
+ seed=training_args.seed,
781
+ )
782
+
783
+ # Data collator
784
+ # This one will take care of randomly masking the tokens.
785
+ data_collator = DataCollatorForT5MLM(
786
+ tokenizer=tokenizer,
787
+ noise_density=data_args.mlm_probability,
788
+ mean_noise_span_length=data_args.mean_noise_span_length,
789
+ input_length=max_seq_length,
790
+ target_length=targets_length,
791
+ pad_token_id=model.config.pad_token_id,
792
+ decoder_start_token_id=model.config.decoder_start_token_id,
793
+ )
794
+
795
+ train_dataset = tokenized_datasets["train"]
796
+ eval_dataset = tokenized_datasets["validation"]
797
+ train_dataloader = DataLoader(
798
+ train_dataset, shuffle=True, collate_fn=data_collator, batch_size=training_args.per_device_train_batch_size
799
+ )
800
+ eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=training_args.per_device_eval_batch_size)
801
+
802
+ # Store some constant
803
+ num_epochs = int(training_args.num_train_epochs)
804
+ train_batch_size = int(training_args.per_device_train_batch_size) #* jax.device_count()
805
+ per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
806
+ eval_batch_size = per_device_eval_batch_size #* jax.device_count()
807
+
808
+ num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
809
+
810
+
811
+ # adam optimizer
812
+ no_decay = ["bias", "LayerNorm.weight"]
813
+ optimizer_grouped_parameters = [
814
+ {
815
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
816
+ "weight_decay": training_args.weight_decay,
817
+ },
818
+ {
819
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
820
+ "weight_decay": 0.0,
821
+ },
822
+ ]
823
+ optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=training_args.learning_rate, betas=(training_args.adam_beta1, training_args.adam_beta2), eps=training_args.adam_epsilon)
824
+
825
+ # scheduler
826
+ lr_scheduler = get_linear_schedule_with_warmup(
827
+ optimizer=optimizer,
828
+ num_warmup_steps= training_args.warmup_steps, #* args.gradient_accumulation_steps,
829
+ num_training_steps=num_train_steps
830
+ )
831
+
832
+ # Prepare everything with our `accelerator`.
833
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
834
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
835
+ )
836
+
837
+ # On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
838
+ if accelerator.distributed_type == DistributedType.TPU:
839
+ model.tie_weights()
840
+
841
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
842
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))# / args.gradient_accumulation_steps)
843
+
844
+ #?
845
+ #num_train_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
846
+
847
+ # Figure out how many steps we should save the Accelerator states
848
+ checkpointing_steps = training_args.save_steps
849
+ if checkpointing_steps is not None and checkpointing_steps.isdigit():
850
+ checkpointing_steps = int(checkpointing_steps)
851
+ # Train!
852
+ total_batch_size = training_args.per_device_train_batch_size * accelerator.num_processes #* args.gradient_accumulation_steps
853
+ # Only show the progress bar once on each machine.
854
+ progress_bar = tqdm(range(num_train_steps), disable=not accelerator.is_local_main_process)
855
+ completed_steps = 0
856
+ starting_epoch = 0
857
+
858
+ #print(training_args.num_train_epochs)
859
+ for epoch in range(starting_epoch, int(training_args.num_train_epochs)):
860
+ model.train()
861
+ active_dataloader = train_dataloader
862
+ for step, batch in enumerate(active_dataloader):
863
+ with accelerator.accumulate(model):
864
+ outputs = model(**batch)
865
+ loss = outputs.loss
866
+ # We keep track of the loss at each epoch
867
+
868
+ accelerator.backward(loss)
869
+ optimizer.step()
870
+ lr_scheduler.step()
871
+ optimizer.zero_grad()
872
+
873
+ # Checks if the accelerator has performed an optimization step behind the scenes
874
+ if accelerator.sync_gradients:
875
+ progress_bar.update(1)
876
+ completed_steps += 1
877
+
878
+ if isinstance(checkpointing_steps, int):
879
+ if completed_steps % checkpointing_steps == 0:
880
+ output_dir = f"step_{completed_steps }"
881
+ if training_args.output_dir is not None:
882
+ output_dir = os.path.join(training_args.output_dir, output_dir)
883
+ accelerator.save_state(output_dir)
884
+
885
+ if completed_steps >= num_train_steps:
886
+ break
887
+
888
+ if step % training_args.eval_steps == 0 and step > 0:
889
+ model.eval()
890
+ losses = []
891
+ for dev_step, batch in enumerate(tqdm(eval_dataloader, desc="Evaluating ...", position=2)):
892
+ with torch.no_grad():
893
+ outputs = model(**batch)
894
+
895
+ loss = outputs.loss
896
+ losses.append(accelerator.gather_for_metrics(loss.repeat(training_args.per_device_eval_batch_size)))
897
+
898
+ losses = torch.cat(losses)
899
+ try:
900
+ eval_loss = torch.mean(losses)
901
+ perplexity = math.exp(eval_loss)
902
+ except OverflowError:
903
+ perplexity = float("inf")
904
+
905
+ logger.info(f"step {step}: perplexity: {perplexity}")
906
+
907
+ model.eval()
908
+ losses = []
909
+ for step, batch in enumerate(eval_dataloader):
910
+ with torch.no_grad():
911
+ outputs = model(**batch)
912
+
913
+ loss = outputs.loss
914
+ losses.append(accelerator.gather_for_metrics(loss.repeat(training_args.per_device_eval_batch_size)))
915
+
916
+ losses = torch.cat(losses)
917
+ try:
918
+ eval_loss = torch.mean(losses)
919
+ perplexity = math.exp(eval_loss)
920
+ except OverflowError:
921
+ perplexity = float("inf")
922
+
923
+ logger.info(f"epoch {epoch}: perplexity: {perplexity}")
924
+
925
+ if training_args.push_to_hub and epoch < training_args.num_train_epochs - 1:
926
+ accelerator.wait_for_everyone()
927
+ unwrapped_model = accelerator.unwrap_model(model)
928
+ unwrapped_model.save_pretrained(
929
+ training_args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
930
+ )
931
+ if accelerator.is_main_process:
932
+ tokenizer.save_pretrained(training_args.output_dir)
933
+ repo.push_to_hub(
934
+ commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
935
+ )
936
+
937
+ if training_args.save_steps == "epoch":
938
+ output_dir = f"epoch_{epoch}"
939
+ if training_args.output_dir is not None:
940
+ output_dir = os.path.join(training_args.output_dir, output_dir)
941
+ accelerator.save_state(output_dir)
942
+
943
+ if training_args.output_dir is not None:
944
+ accelerator.wait_for_everyone()
945
+ unwrapped_model = accelerator.unwrap_model(model)
946
+ unwrapped_model.save_pretrained(
947
+ training_args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
948
+ )
949
+ if accelerator.is_main_process:
950
+ tokenizer.save_pretrained(training_args.output_dir)
951
+ if training_args.push_to_hub:
952
+ repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
953
+
954
+ with open(os.path.join(training_args.output_dir, "all_results.json"), "w") as f:
955
+ json.dump({"perplexity": perplexity}, f)
956
+
957
+
958
+
959
+
960
+ if __name__ == "__main__":
961
+ main()
962
+
special_tokens_map.json ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "eos_token": "</s>",
105
+ "pad_token": "<pad>",
106
+ "unk_token": "<unk>"
107
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<extra_id_0>",
4
+ "<extra_id_1>",
5
+ "<extra_id_2>",
6
+ "<extra_id_3>",
7
+ "<extra_id_4>",
8
+ "<extra_id_5>",
9
+ "<extra_id_6>",
10
+ "<extra_id_7>",
11
+ "<extra_id_8>",
12
+ "<extra_id_9>",
13
+ "<extra_id_10>",
14
+ "<extra_id_11>",
15
+ "<extra_id_12>",
16
+ "<extra_id_13>",
17
+ "<extra_id_14>",
18
+ "<extra_id_15>",
19
+ "<extra_id_16>",
20
+ "<extra_id_17>",
21
+ "<extra_id_18>",
22
+ "<extra_id_19>",
23
+ "<extra_id_20>",
24
+ "<extra_id_21>",
25
+ "<extra_id_22>",
26
+ "<extra_id_23>",
27
+ "<extra_id_24>",
28
+ "<extra_id_25>",
29
+ "<extra_id_26>",
30
+ "<extra_id_27>",
31
+ "<extra_id_28>",
32
+ "<extra_id_29>",
33
+ "<extra_id_30>",
34
+ "<extra_id_31>",
35
+ "<extra_id_32>",
36
+ "<extra_id_33>",
37
+ "<extra_id_34>",
38
+ "<extra_id_35>",
39
+ "<extra_id_36>",
40
+ "<extra_id_37>",
41
+ "<extra_id_38>",
42
+ "<extra_id_39>",
43
+ "<extra_id_40>",
44
+ "<extra_id_41>",
45
+ "<extra_id_42>",
46
+ "<extra_id_43>",
47
+ "<extra_id_44>",
48
+ "<extra_id_45>",
49
+ "<extra_id_46>",
50
+ "<extra_id_47>",
51
+ "<extra_id_48>",
52
+ "<extra_id_49>",
53
+ "<extra_id_50>",
54
+ "<extra_id_51>",
55
+ "<extra_id_52>",
56
+ "<extra_id_53>",
57
+ "<extra_id_54>",
58
+ "<extra_id_55>",
59
+ "<extra_id_56>",
60
+ "<extra_id_57>",
61
+ "<extra_id_58>",
62
+ "<extra_id_59>",
63
+ "<extra_id_60>",
64
+ "<extra_id_61>",
65
+ "<extra_id_62>",
66
+ "<extra_id_63>",
67
+ "<extra_id_64>",
68
+ "<extra_id_65>",
69
+ "<extra_id_66>",
70
+ "<extra_id_67>",
71
+ "<extra_id_68>",
72
+ "<extra_id_69>",
73
+ "<extra_id_70>",
74
+ "<extra_id_71>",
75
+ "<extra_id_72>",
76
+ "<extra_id_73>",
77
+ "<extra_id_74>",
78
+ "<extra_id_75>",
79
+ "<extra_id_76>",
80
+ "<extra_id_77>",
81
+ "<extra_id_78>",
82
+ "<extra_id_79>",
83
+ "<extra_id_80>",
84
+ "<extra_id_81>",
85
+ "<extra_id_82>",
86
+ "<extra_id_83>",
87
+ "<extra_id_84>",
88
+ "<extra_id_85>",
89
+ "<extra_id_86>",
90
+ "<extra_id_87>",
91
+ "<extra_id_88>",
92
+ "<extra_id_89>",
93
+ "<extra_id_90>",
94
+ "<extra_id_91>",
95
+ "<extra_id_92>",
96
+ "<extra_id_93>",
97
+ "<extra_id_94>",
98
+ "<extra_id_95>",
99
+ "<extra_id_96>",
100
+ "<extra_id_97>",
101
+ "<extra_id_98>",
102
+ "<extra_id_99>"
103
+ ],
104
+ "clean_up_tokenization_spaces": true,
105
+ "eos_token": "</s>",
106
+ "extra_ids": 100,
107
+ "model_max_length": 512,
108
+ "pad_token": "<pad>",
109
+ "sp_model_kwargs": {},
110
+ "tokenizer_class": "T5Tokenizer",
111
+ "unk_token": "<unk>"
112
+ }