import os
import time
import warnings
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union

import torch
from filelock import FileLock
from torch.utils.data import Dataset

from ...tokenization_utils_base import PreTrainedTokenizerBase
from ...utils import logging
from ..processors.glue import glue_convert_examples_to_features, glue_output_modes, glue_processors
from ..processors.utils import InputFeatures


logger = logging.get_logger(__name__)


@dataclass
class GlueDataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class into argparse arguments that can be specified on the command
    line.
    """

    task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )

    def __post_init__(self):
        self.task_name = self.task_name.lower()
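

# A minimal usage sketch (hypothetical script name and flag values, not part of this module): as the
# docstring above notes, `HfArgumentParser` can expose the dataclass fields as command-line arguments.
#
#     from transformers import HfArgumentParser
#
#     parser = HfArgumentParser(GlueDataTrainingArguments)
#     (data_args,) = parser.parse_args_into_dataclasses()
#     # e.g. invoked as: python run.py --task_name mrpc --data_dir ./glue_data/MRPC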


class Split(Enum):
    train = "train"
    dev = "dev"
    test = "test"


class GlueDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach soon.
    """

    args: GlueDataTrainingArguments
    output_mode: str
    features: List[InputFeatures]

    def __init__(
        self,
        args: GlueDataTrainingArguments,
        tokenizer: PreTrainedTokenizerBase,
        limit_length: Optional[int] = None,
        mode: Union[str, Split] = Split.train,
        cache_dir: Optional[str] = None,
    ):
        warnings.warn(
            "This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets "
            "library. You can have a look at this example script for pointers: "
            "https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py",
            FutureWarning,
        )
        self.args = args
        self.processor = glue_processors[args.task_name]()
        self.output_mode = glue_output_modes[args.task_name]
        if isinstance(mode, str):
            try:
                mode = Split[mode]
            except KeyError:
                raise KeyError("mode is not a valid split name")
        # Load data features from cached file or dataset file
        cached_features_file = os.path.join(
            cache_dir if cache_dir is not None else args.data_dir,
            f"cached_{mode.value}_{tokenizer.__class__.__name__}_{args.max_seq_length}_{args.task_name}",
        )
        label_list = self.processor.get_labels()
        if args.task_name in ["mnli", "mnli-mm"] and tokenizer.__class__.__name__ in (
            "RobertaTokenizer",
            "RobertaTokenizerFast",
            "XLMRobertaTokenizer",
            "BartTokenizer",
            "BartTokenizerFast",
        ):
            # HACK(label indices are swapped in RoBERTa pretrained model)
            # i.e. ["contradiction", "entailment", "neutral"] -> ["contradiction", "neutral", "entailment"]
            label_list[1], label_list[2] = label_list[2], label_list[1]
        self.label_list = label_list

        # Make sure only the first process in distributed training processes the dataset,
        # and the others will use the cache.
        lock_path = cached_features_file + ".lock"
        with FileLock(lock_path):
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                start = time.time()
                self.features = torch.load(cached_features_file)
                logger.info(
                    f"Loading features from cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                )
            else:
                logger.info(f"Creating features from dataset file at {args.data_dir}")

                if mode == Split.dev:
                    examples = self.processor.get_dev_examples(args.data_dir)
                elif mode == Split.test:
                    examples = self.processor.get_test_examples(args.data_dir)
                else:
                    examples = self.processor.get_train_examples(args.data_dir)
                if limit_length is not None:
                    examples = examples[:limit_length]
                self.features = glue_convert_examples_to_features(
                    examples,
                    tokenizer,
                    max_length=args.max_seq_length,
                    label_list=label_list,
                    output_mode=self.output_mode,
                )
                start = time.time()
                torch.save(self.features, cached_features_file)
                logger.info(
                    f"Saving features into cached file {cached_features_file} [took {time.time() - start:.3f} s]"
                )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

    def get_labels(self):
        return self.label_list
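

# A minimal usage sketch (hypothetical checkpoint and data paths, not part of this module):
# building a dev-split `GlueDataset` and batching it with a PyTorch `DataLoader`. The items
# are `InputFeatures` dataclasses, which `default_data_collator` can turn into tensor batches.
#
#     from torch.utils.data import DataLoader
#     from transformers import AutoTokenizer, GlueDataset, GlueDataTrainingArguments, default_data_collator
#
#     tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#     data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC")
#     dev_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
#     loader = DataLoader(dev_dataset, batch_size=32, collate_fn=default_data_collator)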