# Page header captured with this source: "Spaces: Runtime error"
import textattack | |
import transformers | |
from FlowCorrector import Flow_Corrector | |
import torch | |
import torch.nn.functional as F | |
def count_matching_classes(original, corrected):
    """Return how many positions hold the same class label in both sequences.

    Args:
        original: Sequence of reference labels (in this script, the victim
            model's predictions on the clean texts).
        corrected: Sequence of labels to compare, aligned index-by-index
            with ``original``.

    Returns:
        int: Number of indices ``i`` for which ``original[i] == corrected[i]``
        (0 for two empty sequences).

    Raises:
        ValueError: If the two sequences have different lengths.
    """
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    # Pair labels positionally rather than indexing both sequences by hand;
    # summing literal 1s keeps the result a plain int whatever type
    # ``==`` returns for the label objects.
    return sum(1 for orig, corr in zip(original, corrected) if orig == corr)
def main():
    """Benchmark Flow_Corrector on AG News adversarial examples.

    Builds a word-swap attack against ``textattack/bert-base-uncased-ag-news``
    and wraps it in a ``Flow_Corrector``. It reads adversarial texts from
    ``perturbed_texts_ag_news.txt`` and the line-aligned clean texts from
    ``original_texts_ag_news.txt``. It then prints how many of the
    corrector's outputs match the victim model's predictions on the clean
    texts.

    Raises:
        ValueError: If the two input files do not have the same number of lines.
    """
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.constraints.semantics import WordEmbeddingDistance
    from textattack.search_methods import GreedyWordSwapWIR
    from textattack.transformations import WordSwapEmbedding

    # Load model, tokenizer, and model_wrapper.
    model_name = "textattack/bert-base-uncased-ag-news"
    model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

    # The four components of a textattack `Attack`.
    goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
    constraints = [
        RepeatModification(),
        StopwordModification(),
        WordEmbeddingDistance(min_cos_sim=0.9),
    ]
    transformation = WordSwapEmbedding(max_candidates=50)
    search_method = GreedyWordSwapWIR(wir_method="weighted-saliency")
    attack = textattack.Attack(goal_function, constraints, transformation, search_method)
    attack.cuda_()

    # Initialise the corrector on top of the attack.
    corrector = Flow_Corrector(
        attack,
        word_rank_file="en_full_ranked.json",
        word_freq_file="en_full_freq.json",
    )

    # Adversarial (perturbed) texts, one per line.
    with open("perturbed_texts_ag_news.txt", "r", encoding="utf-8") as f:
        detected_texts = [line.strip() for line in f]
    # Clean texts, in the same order as the adversarial ones.
    with open("original_texts_ag_news.txt", "r", encoding="utf-8") as f:
        original_texts = [line.strip() for line in f]
    # Fail fast: a misalignment would otherwise surface only after every
    # model call has been paid for.
    if len(detected_texts) != len(original_texts):
        raise ValueError(
            f"Input files are misaligned: {len(detected_texts)} perturbed "
            f"texts vs {len(original_texts)} original texts"
        )

    victim_model = attack.goal_function.model
    # Reference labels: the victim model's predictions on the clean texts.
    # AG News labels: 0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech.
    # The wrapper is given a one-element list of strings. The output is
    # presumably logits of shape (1, num_labels) — confirm. Softmax is
    # monotonic, so argmax over the raw logits picks the same class.
    original_classes = [
        torch.argmax(victim_model([text]), dim=1).item() for text in original_texts
    ]

    # Correct the *adversarial* texts. Correcting the clean ones would make
    # the comparison below meaningless.
    # NOTE(review): assumes Flow_Corrector.correct returns one class label per
    # input text, in input order — confirm against FlowCorrector.
    corrected_classes = corrector.correct(detected_texts)

    matches = count_matching_classes(original_classes, corrected_classes)
    print(f"match {matches}/{len(original_classes)}")


if __name__ == "__main__":
    main()