# Source: DCWIR-Demo / textattack/datasets/huggingface_dataset.py
# Uploaded by PFEemp2024 ("add necessary file"), commit 63775f2.
"""
HuggingFaceDataset Class
=========================
TextAttack allows users to provide their own dataset or load from HuggingFace.
"""
import collections
import datasets
import textattack
from .dataset import Dataset
def _cb(s):
    """Return ``s`` (converted with ``str``) colored blue for terminal printing.

    Delegates to ``textattack.shared.utils.color_text`` using the ``"ansi"`` method.
    """
    text = str(s)
    return textattack.shared.utils.color_text(text, color="blue", method="ansi")
def get_datasets_dataset_columns(dataset):
"""Common schemas for datasets found in dataset hub."""
schema = set(dataset.column_names)
if {"premise", "hypothesis", "label"} <= schema:
input_columns = ("premise", "hypothesis")
output_column = "label"
elif {"question", "sentence", "label"} <= schema:
input_columns = ("question", "sentence")
output_column = "label"
elif {"sentence1", "sentence2", "label"} <= schema:
input_columns = ("sentence1", "sentence2")
output_column = "label"
elif {"question1", "question2", "label"} <= schema:
input_columns = ("question1", "question2")
output_column = "label"
elif {"question", "sentence", "label"} <= schema:
input_columns = ("question", "sentence")
output_column = "label"
elif {"context", "question", "title", "answers"} <= schema:
# Common schema for SQUAD dataset
input_columns = ("title", "context", "question")
output_column = "answers"
elif {"text", "label"} <= schema:
input_columns = ("text",)
output_column = "label"
elif {"sentence", "label"} <= schema:
input_columns = ("sentence",)
output_column = "label"
elif {"document", "summary"} <= schema:
input_columns = ("document",)
output_column = "summary"
elif {"content", "summary"} <= schema:
input_columns = ("content",)
output_column = "summary"
elif {"label", "review"} <= schema:
input_columns = ("review",)
output_column = "label"
else:
raise ValueError(
f"Unsupported dataset schema {schema}. Try passing your own `dataset_columns` argument."
)
return input_columns, output_column
class HuggingFaceDataset(Dataset):
    """Loads a dataset from 🤗 Datasets and prepares it as a TextAttack dataset.

    Args:
        name_or_dataset (:obj:`Union[str, datasets.Dataset]`):
            The dataset name as :obj:`str` or actual :obj:`datasets.Dataset` object.
            If it's your custom :obj:`datasets.Dataset` object, please pass the input and output columns via :obj:`dataset_columns` argument.
        subset (:obj:`str`, `optional`, defaults to :obj:`None`):
            The subset of the main dataset. Dataset will be loaded as :obj:`datasets.load_dataset(name, subset)`.
        split (:obj:`str`, `optional`, defaults to :obj:`"train"`):
            The split of the dataset.
        dataset_columns (:obj:`tuple(list[str], str)`, `optional`, defaults to :obj:`None`):
            Pair of :obj:`list[str]` representing list of input column names (e.g. :obj:`["premise", "hypothesis"]`)
            and :obj:`str` representing the output column name (e.g. :obj:`label`). If not set, we will try to automatically determine column names from known designs.
        label_map (:obj:`dict[int, int]`, `optional`, defaults to :obj:`None`):
            Mapping if output labels of the dataset should be re-mapped. Useful if model was trained with a different label arrangement.
            For example, if dataset's arrangement is 0 for `Negative` and 1 for `Positive`, but model's label
            arrangement is 1 for `Negative` and 0 for `Positive`, passing :obj:`{0: 1, 1: 0}` will remap the dataset's label to match with model's arrangements.
            Could also be used to remap literal labels to numerical labels (e.g. :obj:`{"positive": 1, "negative": 0}`).
        label_names (:obj:`list[str]`, `optional`, defaults to :obj:`None`):
            List of label names in corresponding order (e.g. :obj:`["World", "Sports", "Business", "Sci/Tech"]` for AG-News dataset).
            If not set, labels will be printed as is (e.g. "0", "1", ...). This should be set to :obj:`None` for non-classification datasets.
        output_scale_factor (:obj:`float`, `optional`, defaults to :obj:`None`):
            Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
            Some datasets are regression tasks, in which case this is necessary.
        shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to shuffle the underlying dataset.

    .. note::
        Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.
    """

    def __init__(
        self,
        name_or_dataset,
        subset=None,
        split="train",
        dataset_columns=None,
        label_map=None,
        label_names=None,
        output_scale_factor=None,
        shuffle=False,
    ):
        if isinstance(name_or_dataset, datasets.Dataset):
            # A ready-made dataset is used as-is; `subset` and `split` are ignored.
            self._dataset = name_or_dataset
        else:
            self._name = name_or_dataset
            self._subset = subset
            self._dataset = datasets.load_dataset(self._name, subset)[split]
            subset_print_str = f", subset {_cb(subset)}" if subset else ""
            textattack.shared.logger.info(
                f"Loading {_cb('datasets')} dataset {_cb(self._name)}{subset_print_str}, split {_cb(split)}."
            )

        # Input/output column order, like (('premise', 'hypothesis'), 'label')
        (
            self.input_columns,
            self.output_column,
        ) = dataset_columns or get_datasets_dataset_columns(self._dataset)

        if not isinstance(self.input_columns, (list, tuple)):
            raise ValueError(
                "First element of `dataset_columns` must be a list or a tuple."
            )

        self.label_map = label_map
        self.output_scale_factor = output_scale_factor
        if label_names:
            self.label_names = label_names
        else:
            try:
                self.label_names = self._dataset.features[self.output_column].names
            except (KeyError, AttributeError):
                # This happens when the dataset doesn't have 'features' or a 'label' column.
                self.label_names = None

        # If labels are remapped, the label names have to be remapped as well.
        # NOTE(review): for each key k (in label_map's iteration order) this takes
        # label_names[label_map[k]]. That yields names indexed by the remapped label
        # only when label_map is its own inverse (e.g. {0: 1, 1: 0}); confirm the
        # intended result for other mappings before relying on it.
        if self.label_names and label_map:
            self.label_names = [
                self.label_names[self.label_map[i]] for i in self.label_map
            ]

        self.shuffled = shuffle
        if shuffle:
            # `datasets.Dataset.shuffle` returns a new, shuffled dataset; it does not
            # shuffle in place, so the result must be kept.
            self._dataset = self._dataset.shuffle()

    def _format_as_dict(self, example):
        """Convert a raw row into ``(OrderedDict of input columns, output)``.

        The output is passed through ``label_map`` (if set) and then divided by
        ``output_scale_factor`` (if set).
        """
        input_dict = collections.OrderedDict(
            [(c, example[c]) for c in self.input_columns]
        )

        output = example[self.output_column]
        if self.label_map:
            output = self.label_map[output]
        if self.output_scale_factor:
            output = output / self.output_scale_factor

        return (input_dict, output)

    def filter_by_labels_(self, labels_to_keep):
        """Filter items by their labels for classification datasets. Performs
        in-place filtering.

        Labels are compared against the raw values of the output column, i.e.
        before ``label_map`` is applied.

        Args:
            labels_to_keep (:obj:`Union[Set, Tuple, List, Iterable]`):
                Set, tuple, list, or iterable of integers representing labels.
        """
        if not isinstance(labels_to_keep, set):
            labels_to_keep = set(labels_to_keep)
        self._dataset = self._dataset.filter(
            lambda x: x[self.output_column] in labels_to_keep
        )

    def __getitem__(self, i):
        """Return the i-th sample, or a list of samples when ``i`` is a slice.

        Slices follow normal Python semantics: open bounds (``ds[:5]``),
        negative bounds and steps are supported.
        """
        if isinstance(i, int):
            return self._format_as_dict(self._dataset[i])
        # Otherwise `i` is expected to be a slice. `slice.indices` resolves None and
        # negative bounds against the dataset length and also provides the step.
        start, stop, step = i.indices(len(self._dataset))
        return [
            self._format_as_dict(self._dataset[j]) for j in range(start, stop, step)
        ]

    def shuffle(self):
        """Shuffle the underlying dataset and mark this dataset as shuffled."""
        # `datasets.Dataset.shuffle` returns a shuffled copy rather than mutating.
        self._dataset = self._dataset.shuffle()
        self.shuffled = True