|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import logging |
|
import math |
|
import os |
|
import sys |
|
import time |
|
import warnings |
|
from dataclasses import asdict, dataclass, field |
|
|
|
from enum import Enum |
|
from itertools import chain |
|
from pathlib import Path |
|
from typing import Dict, List, Optional |
|
|
|
from datasets import load_dataset |
|
from huggingface_hub import Repository, create_repo |
|
from tqdm import tqdm |
|
from torch.utils.data import DataLoader |
|
|
|
from accelerate import Accelerator, DistributedType |
|
from accelerate.logging import get_logger |
|
from accelerate.utils import set_seed |
|
from transformers import ( |
|
CONFIG_MAPPING, |
|
MODEL_FOR_MASKED_LM_MAPPING, |
|
AutoTokenizer, |
|
BatchEncoding, |
|
T5ForConditionalGeneration, |
|
HfArgumentParser, |
|
PreTrainedTokenizerBase, |
|
T5Config, |
|
is_tensorboard_available, |
|
set_seed, |
|
) |
|
from transformers.utils import send_example_telemetry |
|
from transformers import AutoModel, get_linear_schedule_with_warmup |
|
import torch |
|
# Seed torch's global RNG as soon as the module is imported.
torch.manual_seed(8446)

# Keys of the masked-LM model mapping, and the `model_type` string of each one
# (the latter are listed in the `model_type` help text of `ModelArguments`).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)
|
|
|
|
|
def shift_tokens_right(input_ids, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.

    Args:
        input_ids: integer tensor indexed as `[row, position]`.
        pad_token_id: id written wherever the shifted result contains `-100`.
        decoder_start_token_id: id written into the first position of every row.

    Returns:
        A new tensor with the same shape, dtype and device as `input_ids`; the last
        token of every row is dropped by the shift. `input_ids` is not modified.
    """
    # `new_zeros` keeps dtype *and* device, so the slice copy below also works for
    # inputs that do not live on the default device.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)

    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    # -100 entries are replaced by the pad id so the result contains only real token ids.
    shifted_input_ids[shifted_input_ids == -100] = pad_token_id

    return shifted_input_ids
|
|
|
@dataclass
class TrainingArguments:
    """
    Training options of this script, parsed from the command line or a JSON file.

    `output_dir` is the only required field. `__post_init__` expands a leading `~`
    in it. `to_dict` returns a JSON-friendly copy with token values masked.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    # Kept as a string because the training loop accepts either a number of steps or "epoch".
    save_steps: Optional[str] = field(
        default=None,
        metadata={
            "help": "Save checkpoint every X updates steps, or pass 'epoch' to save at the end of every epoch."
        },
    )
    eval_steps: int = field(default=100, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        """Expand a leading `~` in `output_dir`."""
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
        the token values by removing their value.
        """
        d = asdict(self)
        # Only values of existing keys are replaced, so the dict size never changes while iterating.
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                # Every `*_token` field is masked, whether or not a token was actually set.
                d[k] = f"<{k.upper()}>"
        return d
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`; `None` means "not passed on the command line".
    use_auth_token: Optional[bool] = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
        },
    )
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization and masking. Sequences longer than this"
                " will be truncated. Default to the max input length of the model."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
    )
    mean_noise_span_length: float = field(
        default=3.0,
        metadata={"help": "Mean span length of masked tokens"},
    )

    def __post_init__(self):
        """
        Validate the data source options.

        Raises:
            ValueError: if neither a dataset name nor a train/validation file is given, or if
                a given file does not end in `csv`, `json` or `txt`.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")

        # Raise explicitly instead of using `assert`: asserts are stripped under
        # `python -O`, which would silently disable this validation.
        for arg_name in ("train_file", "validation_file"):
            path = getattr(self, arg_name)
            if path is None:
                continue
            extension = path.split(".")[-1]
            if extension not in ("csv", "json", "txt"):
                raise ValueError(f"`{arg_name}` should be a csv, a json or a txt file.")
|
|
|
|
|
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """Find the raw token count that yields `inputs_length` encoder tokens after span corruption.

    Port of `random_spans_helper` from the T5 preprocessors
    (https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466).

    Each noise span is counted as one sentinel token on the input side and one on
    the target side, and both sides get one extra token for EOS. The function
    returns the largest raw length whose corrupted input still fits into
    `inputs_length`, so that no padding is needed.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence
    """

    def _encoded_lengths(raw_length):
        # (encoder length, decoder length) produced by `raw_length` raw tokens.
        noise_count = int(round(raw_length * noise_density))
        span_count = int(round(noise_count / mean_noise_span_length))
        kept_count = raw_length - noise_count
        # "+ span_count": one sentinel per span; "+ 1": EOS.
        return kept_count + span_count + 1, noise_count + span_count + 1

    tokens_length = inputs_length

    # Grow the raw length for as long as one more token still fits the encoder budget.
    while True:
        next_inputs_length, _ = _encoded_lengths(tokens_length + 1)
        if next_inputs_length > inputs_length:
            break
        tokens_length += 1

    inputs_length, targets_length = _encoded_lengths(tokens_length)

    # Special case kept from the reference implementation: at 50% density the
    # targets may come out one token longer than the inputs; trim one token.
    if noise_density == 0.5 and targets_length > inputs_length:
        return tokens_length - 1, targets_length - 1
    return tokens_length, targets_length
|
|
|
|
|
class DataCollatorForT5MLM:
    """
    Data collator used for T5 span-masked language modeling.
    It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of fixed length.
    For more information on how T5 span-masked language modeling works, one can take a look
    at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
    or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        noise_density (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input.
        mean_noise_span_length (:obj:`float`):
            The average span length of the masked tokens.
        input_length (:obj:`int`):
            The expected input length after masking.
        target_length (:obj:`int`):
            The expected target length after masking.
        pad_token_id: (:obj:`int`):
            The pad token id of the model
        decoder_start_token_id: (:obj:`int`):
            The decoder start token id of the model
    """

    def __init__(
        self,
        tokenizer: "PreTrainedTokenizerBase",
        noise_density: float,
        mean_noise_span_length: float,
        input_length: int,
        target_length: int,
        pad_token_id: int,
        decoder_start_token_id: int,
    ):
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
        self.input_length = input_length
        self.target_length = target_length
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id

    def __call__(self, examples: List[Dict[str, list]]) -> "BatchEncoding":
        """
        Collate `examples` into a span-corrupted batch.

        Rows shorter than the longest one are right-padded with `pad_token_id` first.

        Returns:
            A `BatchEncoding` with `input_ids`, `labels` and `decoder_input_ids`.

        Raises:
            ValueError: if the corrupted inputs/labels do not have exactly
                `input_length` / `target_length` tokens.
        """
        rows = [example["input_ids"] for example in examples]
        max_len = max(len(row) for row in rows)
        # Pad copies of the rows so the caller's example lists are left untouched.
        padded_rows = [list(row) + [self.pad_token_id] * (max_len - len(row)) for row in rows]
        batch = BatchEncoding({"input_ids": torch.tensor(padded_rows)})

        input_ids = batch["input_ids"]
        batch_size, expanded_input_length = input_ids.shape

        # One independent noise mask per row; the labels mask is its complement.
        mask_indices = torch.stack(
            [self.random_spans_noise_mask(expanded_input_length) for _ in range(batch_size)]
        )
        labels_mask = ~mask_indices

        input_ids_sentinel = self.create_sentinel_ids(mask_indices)
        labels_sentinel = self.create_sentinel_ids(labels_mask)

        batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
        batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)

        if batch["input_ids"].shape[-1] != self.input_length:
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
                f" should be {self.input_length}."
            )

        if batch["labels"].shape[-1] != self.target_length:
            raise ValueError(
                f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
                f" {self.target_length}."
            )

        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.pad_token_id, self.decoder_start_token_id
        )
        return batch

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.

        The n-th span of a row (1-based) gets id `len(tokenizer) - n`; unmasked
        positions are `0`.
        """
        mask = mask_indices.to(torch.int8)
        # 1 exactly where a masked position is not preceded by another masked position.
        start_indices = mask - torch.roll(mask, 1, dims=-1) * mask
        # `roll` wraps around, so the first column is set directly from the mask.
        start_indices[:, 0] = mask[:, 0]

        # Running count of span starts == 1-based index of the span a start belongs to.
        span_rank = torch.cumsum(start_indices, dim=-1, dtype=torch.long)
        sentinel_ids = torch.where(
            start_indices != 0, len(self.tokenizer) - span_rank, torch.zeros_like(span_rank)
        )
        # Masked positions that are not a span start become -1 (dropped by `filter_input_ids`).
        sentinel_ids -= mask - start_indices

        return sentinel_ids

    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.

        An EOS token is appended to every row. Every row must keep the same number of
        tokens, otherwise the reshape fails.
        """
        batch_size = input_ids.shape[0]

        input_ids_full = torch.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # Negative entries mark the positions to delete.
        kept = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
        # Build the EOS column with the ids' own dtype/device instead of a hard-coded int32.
        eos_column = torch.full(
            (batch_size, 1), self.tokenizer.eos_token_id, dtype=kept.dtype, device=kept.device
        )
        return torch.cat([kept, eos_column], dim=-1)

    def random_spans_noise_mask(self, length):
        """This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)

        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length

        num_noise_tokens = round(length * self.noise_density)
        # Avoid degeneracy by ensuring positive numbers of noise and non-noise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        # Must be derived from the *clamped* noise count; otherwise the span lengths
        # add up to `length + 1` for short sequences and indexing below goes out of range.
        num_nonnoise_tokens = length - num_noise_tokens

        num_noise_spans = round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length)
        # Avoid degeneracy by ensuring a positive number of noise spans.
        num_noise_spans = max(num_noise_spans, 1)

        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add
                up to num_items
            """
            # `num_segments - 1` boundary flags, shuffled over the `num_items - 1` gaps.
            boundaries = torch.arange(num_items - 1) < (num_segments - 1)
            boundaries = boundaries[torch.randperm(boundaries.nelement())]

            first_in_segment = torch.cat([torch.tensor([False]), boundaries])
            segment_id = torch.cumsum(first_in_segment, dim=0)

            _, segment_length = torch.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

        # Interleave as non-noise, noise, non-noise, noise, ...
        interleaved_span_lengths = torch.reshape(
            torch.stack([nonnoise_span_lengths, noise_span_lengths], dim=1), [num_noise_spans * 2]
        )
        span_starts = torch.cumsum(interleaved_span_lengths, dim=0)[:-1]
        span_start_indicator = torch.zeros((length,), dtype=torch.int8)
        span_start_indicator[span_starts] = True
        span_num = torch.cumsum(span_start_indicator, dim=0)
        # Odd-numbered spans are the noise spans.
        is_noise = span_num % 2 == 1

        return is_noise[:orig_length]
|
|
|
|
|
def generate_batch_splits(samples_idx: list, batch_size: int, drop_last=True) -> list:
    """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
    the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.

    Args:
        samples_idx: 1-D sequence or tensor of sample indices.
        batch_size: number of indices per batch.
        drop_last: whether to drop a trailing batch smaller than `batch_size`.

    Returns:
        With `drop_last=True`, a tensor of shape `(num_batches, batch_size)`. Otherwise
        the tuple of 1-D tensors returned by `torch.split`, whose last element may be
        shorter than `batch_size`.
    """
    # Accept plain lists too; tensors are passed through unchanged.
    samples_idx = torch.as_tensor(samples_idx)
    num_samples = len(samples_idx)

    if drop_last:
        num_batches = num_samples // batch_size
        return samples_idx[: num_batches * batch_size].reshape((num_batches, batch_size))

    # `torch.split` takes the chunk *size* (unlike `numpy.array_split`, which takes the
    # number of sections), so the batch size itself is the right argument here.
    return torch.split(samples_idx, batch_size)
|
|
|
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """
    Log the elapsed training time and the buffered per-step training metrics.

    Args:
        summary_writer: object exposing `scalar(tag, value, step)`.
        train_metrics: either a sequence of per-step metric dicts (oldest first) or a
            dict mapping a metric name to its sequence of per-step values.
        train_time: value logged under the `train_time` tag.
        step: global step of the most recent entry; earlier entries are logged at the
            steps directly preceding it.
    """
    summary_writer.scalar("train_time", train_time, step)

    # The original code called a `get_metrics` helper that is not defined or imported
    # in this file; regroup the per-step dicts into {name: [values...]} locally instead.
    if isinstance(train_metrics, dict):
        stacked_metrics = train_metrics
    else:
        stacked_metrics = {}
        for step_metrics in train_metrics:
            for key, val in step_metrics.items():
                stacked_metrics.setdefault(key, []).append(val)

    for key, vals in stacked_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            # The last value belongs to `step`, the one before it to `step - 1`, and so on.
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
|
|
|
|
def write_eval_metric(summary_writer, eval_metrics, step):
    """Log every entry of `eval_metrics` through `summary_writer.scalar` under an `eval_`-prefixed tag at `step`."""
    for name in eval_metrics:
        tag = f"eval_{name}"
        summary_writer.scalar(tag, eval_metrics[name], step)
|
|
|
|
|
def _evaluate_perplexity(model, eval_dataloader, accelerator, per_device_eval_batch_size, show_progress=False):
    """Run one pass over `eval_dataloader` and return `exp(mean loss)` (`inf` on overflow).

    The model is switched to eval mode and left there; the caller decides when to
    switch back to training mode.
    """
    model.eval()
    batches = tqdm(eval_dataloader, desc="Evaluating ...", position=2) if show_progress else eval_dataloader
    losses = []
    for batch in batches:
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        # Repeat the scalar batch loss once per sample before gathering across processes.
        losses.append(accelerator.gather_for_metrics(loss.repeat(per_device_eval_batch_size)))

    losses = torch.cat(losses)
    try:
        eval_loss = torch.mean(losses)
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")
    return perplexity


def main():
    """
    Script entry point: parse arguments, build the datasets, tokenizer and model,
    train with Accelerate, evaluate perplexity and save the results.

    Arguments are read from a single `.json` file when that is the only command-line
    argument, otherwise from the command line.

    Raises:
        ValueError: on conflicting `token`/`use_auth_token`, a non-empty output
            directory without `--overwrite_output_dir`, or a missing tokenizer source.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    accelerator = Accelerator()

    if model_args.use_auth_token is not None:
        warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )
    logger = logging.getLogger(__name__)
    logger.info(f"Training/evaluation parameters {training_args}")

    set_seed(training_args.seed)

    # Create/clone the Hub repository into the output directory when pushing is requested.
    if training_args.push_to_hub:
        repo_name = training_args.hub_model_id
        if repo_name is None:
            repo_name = Path(training_args.output_dir).absolute().name
        repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
        repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)

    # ------------------------------------------------------------------ data
    if data_args.dataset_name is not None:
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )

        # No validation split: carve the first N% of train out as validation.
        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"

        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
            )

    # ------------------------------------------------- tokenizer and config
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            token=model_args.token,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            token=model_args.token,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            token=model_args.token,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    # ---------------------------------------------------------- preprocessing
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # `--max_seq_length` is optional and documented to default to the model maximum;
    # `min(None, ...)` would raise a TypeError.
    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
    else:
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name], return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Raw chunk length that turns into exactly `max_seq_length` encoder tokens after
    # span corruption, and the matching decoder target length.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    def group_texts(examples):
        # Concatenate all texts of the batch, then cut into chunks of `expanded_inputs_length`.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the remainder that does not fill a whole chunk.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
        result = {
            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # ------------------------------------------------------------------ model
    if model_args.model_name_or_path:
        model = T5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            token=model_args.token,
        )
    else:
        config.vocab_size = len(tokenizer)
        # The PyTorch model constructor takes no `seed` argument; initialization is
        # already made reproducible by `set_seed` above.
        model = T5ForConditionalGeneration(config)

    data_collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=training_args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=data_collator, batch_size=training_args.per_device_eval_batch_size
    )

    num_epochs = int(training_args.num_train_epochs)
    # Count real dataloader batches (the last partial batch is kept) instead of
    # `len(dataset) // batch_size`, which under-counts the steps actually taken.
    num_train_steps = len(train_dataloader) * num_epochs

    # ------------------------------------------------ optimizer and scheduler
    # "layer_norm.weight" covers T5's layer-norm parameter names; "LayerNorm.weight"
    # alone never matches them, which would apply weight decay to the norms.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
    )

    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=num_train_steps,
    )

    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # On TPU the weights need to be re-tied after `prepare`.
    if accelerator.distributed_type == DistributedType.TPU:
        model.tie_weights()

    # The prepared dataloader may be sharded across processes; recompute the number of
    # steps this process actually performs for the progress bar and the stop condition.
    num_train_steps = len(train_dataloader) * num_epochs

    # `save_steps` is either a number of steps (possibly given as a string) or "epoch".
    checkpointing_steps = training_args.save_steps
    if checkpointing_steps is not None and str(checkpointing_steps).isdigit():
        checkpointing_steps = int(checkpointing_steps)

    progress_bar = tqdm(range(num_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    # Stays None when no epoch is run, so the final results dump cannot hit an unbound name.
    perplexity = None

    # ------------------------------------------------------------ train loop
    for epoch in range(starting_epoch, num_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Only count a step when the accelerator actually performed an optimizer update.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                completed_steps += 1

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps }"
                    if training_args.output_dir is not None:
                        output_dir = os.path.join(training_args.output_dir, output_dir)
                    accelerator.save_state(output_dir)

            if completed_steps >= num_train_steps:
                break

            if step % training_args.eval_steps == 0 and step > 0:
                step_perplexity = _evaluate_perplexity(
                    model,
                    eval_dataloader,
                    accelerator,
                    training_args.per_device_eval_batch_size,
                    show_progress=True,
                )
                logger.info(f"step {step}: perplexity: {step_perplexity}")
                # Evaluation switched the model to eval mode; go back to training mode
                # so the rest of the epoch does not train with dropout disabled.
                model.train()

        perplexity = _evaluate_perplexity(
            model, eval_dataloader, accelerator, training_args.per_device_eval_batch_size
        )
        logger.info(f"epoch {epoch}: perplexity: {perplexity}")

        # Intermediate push after every epoch except the last (the final push happens below).
        if training_args.push_to_hub and epoch < training_args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                training_args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
            if accelerator.is_main_process:
                tokenizer.save_pretrained(training_args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

        if training_args.save_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if training_args.output_dir is not None:
                output_dir = os.path.join(training_args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    # ------------------------------------------------------------- final save
    if training_args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            training_args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)

            with open(os.path.join(training_args.output_dir, "all_results.json"), "w") as f:
                json.dump({"perplexity": perplexity}, f)


if __name__ == "__main__":
    main()
|
|
|
|