Spaces:
Runtime error
Runtime error
""" | |
PeekDatasetCommand class | |
============================== | |
""" | |
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | |
import collections | |
import re | |
import numpy as np | |
import textattack | |
from textattack.commands import TextAttackCommand | |
# Handle on TextAttack's shared logger (``textattack.shared.logger``); the
# summary lines printed by this command are emitted through it.
logger = textattack.shared.logger


def _cb(s):
    """Return ``str(s)`` highlighted in blue for terminal output.

    Thin wrapper around ``textattack.shared.utils.color_text`` using its
    ``"ansi"`` method; used to make the numeric values in the summary stand out.
    """
    as_text = str(s)
    return textattack.shared.utils.color_text(as_text, color="blue", method="ansi")


class PeekDatasetCommand(TextAttackCommand):
    """The ``peek-dataset`` command: print summary statistics for a dataset.

    Loads the dataset described by the ``DatasetArgs`` command-line options and
    reports the sample count, word-count statistics, whether the text contains
    any uppercase letters, the first and last samples, and the distribution of
    outputs (labels).
    """

    def run(self, args):
        """Load the dataset selected by ``args`` and log its statistics.

        Args:
            args: parsed command-line namespace; every attribute is forwarded
                as a keyword argument to ``textattack.DatasetArgs``.
        """
        # Only ASCII A-Z counts as uppercase; non-ASCII capitals are not seen.
        uppercase_letters_regex = re.compile("[A-Z]")

        dataset_args = textattack.DatasetArgs(**vars(args))
        dataset = textattack.DatasetArgs._create_dataset_from_args(dataset_args)

        num_words = []
        attacked_texts = []
        outputs = []
        data_all_lowercased = True
        for inputs, output in dataset:
            at = textattack.shared.AttackedText(inputs)
            # Test whether the text contains an uppercase letter. Once one has
            # been found the flag can never flip back, so later samples skip
            # the regex search entirely.
            if data_all_lowercased and uppercase_letters_regex.search(at.text):
                data_all_lowercased = False
            attacked_texts.append(at)
            num_words.append(len(at.words))
            outputs.append(output)

        logger.info(f"Number of samples: {_cb(len(attacked_texts))}")
        if not attacked_texts:
            # An empty dataset has no word statistics (numpy's min/max raise on
            # zero-size arrays) and no first/last sample, so stop here rather
            # than crash.
            logger.info("Dataset is empty; nothing more to report.")
            return

        logger.info("Number of words per input:")
        num_words = np.array(num_words)
        word_stats = (
            ("total:", num_words.sum()),
            ("mean:", f"{num_words.mean():.2f}"),
            ("std:", f"{num_words.std():.2f}"),
            ("min:", num_words.min()),
            ("max:", num_words.max()),
        )
        for label, value in word_stats:
            logger.info(f"\t{label.ljust(8)} {_cb(value)}")
        logger.info(f"Dataset lowercased: {_cb(data_all_lowercased)}")

        logger.info("First sample:")
        print(attacked_texts[0].printable_text(), "\n")
        logger.info("Last sample:")
        print(attacked_texts[-1].printable_text(), "\n")

        distinct_outputs = set(outputs)
        logger.info(f"Found {len(distinct_outputs)} distinct outputs.")
        # List every distinct output only when there are few enough of them to
        # read at a glance.
        if len(distinct_outputs) < 20:
            print(sorted(distinct_outputs))

        logger.info("Most common outputs:")
        for key, count in collections.Counter(outputs).most_common(20):
            # Labels are truncated/padded to 5 characters to keep columns aligned.
            print("\t", str(key)[:5].ljust(5), f" ({count})")

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Add the ``peek-dataset`` subcommand and its options.

        Declared static so it behaves the same whether called on the class or
        on an instance (``main_parser`` is never bound to ``self``).

        Args:
            main_parser: object whose ``add_parser`` creates the subcommand
                parser -- in practice presumably the action returned by
                ``ArgumentParser.add_subparsers()``; the annotation is kept for
                compatibility. TODO confirm against the CLI entry point.
        """
        parser = main_parser.add_parser(
            "peek-dataset",
            help="show main statistics about a dataset",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        # Presumably adds the dataset-selection options; it returns the parser
        # that the defaults are then set on.
        parser = textattack.DatasetArgs._add_parser_args(parser)
        # NOTE(review): the CLI entry point presumably calls ``args.func.run(args)``
        # for the selected subcommand -- verify.
        parser.set_defaults(func=PeekDatasetCommand())