Spaces:
Runtime error
Runtime error
""" | |
AugmentCommand class | |
=========================== | |
""" | |
from argparse import ArgumentDefaultsHelpFormatter, ArgumentError, ArgumentParser | |
import csv | |
import os | |
import time | |
import tqdm | |
import textattack | |
from textattack.augment_args import AUGMENTATION_RECIPE_NAMES | |
from textattack.commands import TextAttackCommand | |
class AugmentCommand(TextAttackCommand):
    """The TextAttack augment module:

    A command line parser to run data augmentation from user
    specifications, either interactively (one sentence at a time read from
    stdin) or in batch over the rows of a CSV file.
    """

    # Horizontal rule printed between sections of interactive output.
    _SEPARATOR = "--------------------------------------------------------"

    def run(self, args):
        """Run augmentation as described by the parsed command-line ``args``.

        Reads in a CSV, performs augmentation, and outputs an augmented CSV,
        preserving all columns except for the input (augmented) column. When
        ``args.interactive`` is set, sentences are instead read from stdin and
        their augmentations are printed.

        Args:
            args: parsed command-line arguments; ``vars(args)`` must be valid
                keyword arguments for ``textattack.AugmenterArgs``.

        Raises:
            ArgumentError: CSV mode without an input CSV, input column and
                output CSV.
            FileNotFoundError: ``args.input_csv`` does not exist.
            OSError: ``args.output_csv`` exists and ``args.overwrite`` is not
                set.
            ValueError: the CSV header does not contain ``args.input_column``.
        """
        args = textattack.AugmenterArgs(**vars(args))
        if args.interactive:
            self._run_interactive(args)
        else:
            self._run_csv(args)

    @staticmethod
    def _build_augmenter(args, **extra_kwargs):
        """Instantiate the augmentation recipe named by ``args.recipe``.

        Every construction site goes through here so that options such as
        ``high_yield`` and ``fast_augment`` are never silently dropped.
        ``extra_kwargs`` are forwarded to the recipe constructor as well.

        Raises:
            KeyError: ``args.recipe`` is not a key of AUGMENTATION_RECIPE_NAMES.
        """
        # NOTE: eval only ever sees values stored in AUGMENTATION_RECIPE_NAMES;
        # the user-supplied recipe name is used purely as a dict key, so an
        # unknown name raises KeyError before anything is evaluated.
        recipe_class = eval(AUGMENTATION_RECIPE_NAMES[args.recipe])
        return recipe_class(
            pct_words_to_swap=args.pct_words_to_swap,
            transformations_per_example=args.transformations_per_example,
            high_yield=args.high_yield,
            fast_augment=args.fast_augment,
            **extra_kwargs,
        )

    def _run_interactive(self, args):
        """Augment sentences typed on stdin until the user enters "q"."""
        print("\nRunning in interactive mode...\n")
        augmenter = self._build_augmenter(
            args, enable_advanced_metrics=args.enable_advanced_metrics
        )
        print(self._SEPARATOR)
        while True:
            print(
                '\nEnter a sentence to augment, "q" to quit, "c" to view/change arguments:\n'
            )
            text = input()
            if text == "q":
                break
            if text == "c":
                new_augmenter = self._change_arguments(args)
                if new_augmenter is not None:
                    augmenter = new_augmenter
                continue
            if not text:
                continue

            print("\nAugmenting...\n")
            print(self._SEPARATOR)
            if args.enable_advanced_metrics:
                # With advanced metrics enabled, augment() returns a sequence:
                # the augmentations, then perplexity stats, then USE stats.
                results = augmenter.augment(text)
                print("Augmentations:\n")
                for augmentation in results[0]:
                    print(augmentation, "\n")
                print()
                print(
                    f"Average Original Perplexity Score: {results[1]['avg_original_perplexity']}"
                )
                print(
                    f"Average Augment Perplexity Score: {results[1]['avg_attack_perplexity']}"
                )
                print(
                    f"Average Augment USE Score: {results[2]['avg_attack_use_score']}\n"
                )
            else:
                for augmentation in augmenter.augment(text):
                    print(augmentation, "\n")
            print(self._SEPARATOR)

    def _change_arguments(self, args):
        """Show the current arguments and let the user change them.

        Returns:
            A freshly built augmenter if the arguments were changed, otherwise
            ``None``. On an unknown recipe or a non-numeric value the user is
            told and ``args`` is left untouched, instead of crashing the
            interactive session.
        """
        print(
            f"\nCurrent Arguments:\n\n\t augmentation recipe: {args.recipe}, "
            f"\n\t pct_words_to_swap: {args.pct_words_to_swap}, "
            f"\n\t transformations_per_example: {args.transformations_per_example}\n"
        )
        change = input(
            "Enter 'c' again to change arguments, any other keys to opt out\n"
        )
        if change != "c":
            return None

        print("\nChanging augmenter arguments...\n")
        recipe = input("\tAugmentation recipe name ('r' to see available recipes): ")
        if recipe == "r":
            recipe_display = " ".join(AUGMENTATION_RECIPE_NAMES.keys())
            print(f"\n\t{recipe_display}\n")
            recipe = input("\tAugmentation recipe name: ")
        if recipe not in AUGMENTATION_RECIPE_NAMES:
            print(
                f"\nUnknown augmentation recipe {recipe!r}; keeping current arguments.\n"
            )
            return None
        try:
            pct_words_to_swap = float(
                input("\tPercentage of words to swap (0.0 ~ 1.0): ")
            )
            transformations_per_example = int(
                input("\tTransformations per input example: ")
            )
        except ValueError as err:
            print(f"\nInvalid value ({err}); keeping current arguments.\n")
            return None

        # Commit only after every value parsed successfully.
        args.recipe = recipe
        args.pct_words_to_swap = pct_words_to_swap
        args.transformations_per_example = transformations_per_example

        print("\nGenerating new augmenter...\n")
        # Rebuild with the same options as the initial augmenter; in
        # particular enable_advanced_metrics must match, because the printing
        # code chooses the result format based on it.
        augmenter = self._build_augmenter(
            args, enable_advanced_metrics=args.enable_advanced_metrics
        )
        print(self._SEPARATOR)
        return augmenter

    def _run_csv(self, args):
        """Augment ``args.input_column`` of every row in ``args.input_csv``.

        Writes the original rows (unless ``args.exclude_original`` is set)
        followed by one row per augmentation to ``args.output_csv``. All other
        columns are copied unchanged into each augmented row.
        """
        textattack.shared.utils.set_seed(args.random_seed)
        start_time = time.time()
        if not (args.input_csv and args.input_column and args.output_csv):
            # argparse.ArgumentError takes (argument, message); None means the
            # error is not tied to a single argparse action.
            raise ArgumentError(
                None,
                "The following arguments are required: "
                "--input-csv, --input-column/--i, --output-csv",
            )
        # Validate input/output paths.
        if not os.path.exists(args.input_csv):
            raise FileNotFoundError(f"Can't find CSV at location {args.input_csv}")
        if os.path.exists(args.output_csv):
            if args.overwrite:
                textattack.shared.logger.info(
                    f"Preparing to overwrite {args.output_csv}."
                )
            else:
                raise OSError(
                    f"Outfile {args.output_csv} exists and --overwrite not set."
                )

        # Read in CSV file as a list of dictionaries. The CSV sniffer infers
        # the delimiter (";" or ",") from the header line. The csv module
        # itself handles quoted fields containing delimiters and escaped
        # quotes, so values reach the augmenter without quoting characters
        # and "/" characters in the data are preserved.
        with open(args.input_csv, "r", newline="") as csv_file:
            dialect = csv.Sniffer().sniff(csv_file.readline(), delimiters=";,")
            csv_file.seek(0)
            reader = csv.DictReader(csv_file, dialect=dialect, skipinitialspace=True)
            rows = list(reader)
            # Taken from the header rather than rows[0], so a header-only
            # file does not raise IndexError.
            fieldnames = list(reader.fieldnames or [])

        # Validate input column.
        row_keys = set(fieldnames)
        if args.input_column not in row_keys:
            raise ValueError(
                f"Could not find input column {args.input_column} in CSV. Found keys: {row_keys}"
            )
        textattack.shared.logger.info(
            f"Read {len(rows)} rows from {args.input_csv}. Found columns {row_keys}."
        )

        augmenter = self._build_augmenter(args)
        output_rows = []
        for row in tqdm.tqdm(rows, desc="Augmenting rows"):
            text_input = row[args.input_column]
            if not args.exclude_original:
                output_rows.append(row)
            for augmentation in augmenter.augment(text_input):
                augmented_row = row.copy()
                augmented_row[args.input_column] = augmentation
                output_rows.append(augmented_row)

        # Print to file with standard (minimal) CSV quoting, so values that
        # contain commas, quotes or newlines round-trip correctly. Values
        # beyond the header (collected by DictReader under the None key) have
        # no column to go in and are dropped.
        with open(args.output_csv, "w", newline="") as outfile:
            csv_writer = csv.DictWriter(
                outfile,
                fieldnames=fieldnames,
                delimiter=",",
                quoting=csv.QUOTE_MINIMAL,
                extrasaction="ignore",
            )
            csv_writer.writeheader()
            csv_writer.writerows(output_rows)
        textattack.shared.logger.info(
            f"Wrote {len(output_rows)} augmentations to {args.output_csv} in {time.time() - start_time}s."
        )

    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``augment`` subcommand and its arguments.

        Args:
            main_parser: despite the annotation, ``add_parser`` is called on
                it, so this is presumably the object returned by
                ``ArgumentParser.add_subparsers()`` — confirm against caller.
        """
        parser = main_parser.add_parser(
            "augment",
            help="augment text data",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        parser = textattack.AugmenterArgs._add_parser_args(parser)
        # The CLI dispatcher invokes this command instance for "augment".
        parser.set_defaults(func=AugmentCommand())