# resrer-pegasus-x / eval.py
# Committed by seonglae in a9082f6 ("Training in progress, step 500").
# File size: 2.67 kB; Hub malware scan: no virus detected.
import re
import string
import unicodedata
from evaluate import evaluator, QuestionAnsweringEvaluator
from datasets import load_dataset
def evaluate_dataset(id: str, subset: str, metric: str = 'squad_v2',
                     question_col: str = 'question', context_col: str = 'retrieved',
                     predict_col: str = 'predicted', id_col: str = 'question',
                     label_col: str = 'answer', labeling: bool = True):
    """Score stored QA predictions with a Hugging Face ``evaluate`` metric.

    Loads the ``train`` split of ``id``/``subset``, builds metric references
    via the ``question-answering`` evaluator's ``prepare_data``, attaches
    predictions read straight from ``predict_col`` (no model is run here),
    and returns the result of ``compute_metric``.

    Args:
        id: Dataset name passed to ``load_dataset``.
        subset: Dataset configuration passed to ``load_dataset``.
        metric: Metric name passed to ``prepare_metric``. When it equals
            ``'squad_v2'`` each prediction also carries
            ``no_answer_probability = 0.0``.
        question_col: Question column name handed to ``prepare_data``.
        context_col: Context column name handed to ``prepare_data``.
        predict_col: Column holding the predicted answer text.
        id_col: Id column handed to ``prepare_data``; its value is also used
            as each prediction's ``id``.
        label_col: Label column name handed to ``prepare_data``.
        labeling: If True, each reference's ``answers`` value is rewritten to
            ``{'answer_start': [...], 'text': <original answers>}``, where each
            start offset comes from ``str.find`` on that row's context
            (``-1`` when the answer text is not found).

    Returns:
        Whatever ``compute_metric`` returns for the chosen metric.
    """
    qa_evaluator: QuestionAnsweringEvaluator = evaluator("question-answering")
    # NOTE(review): this function never runs a model pipeline, so this
    # setting looks inert here; it may also be shared class-level state —
    # confirm whether it is needed.
    qa_evaluator.PIPELINE_KWARGS["handle_impossible_answer"] = True

    train_split = load_dataset(id, subset)['train']
    rows = list(train_split)
    metric_input, pipeline_input = qa_evaluator.prepare_data(
        train_split, question_col, context_col, id_col, label_col)

    if labeling:
        # NOTE(review): assumes each label is an iterable of answer strings
        # and each context is a single string — confirm against the schema.
        for idx, ref in enumerate(metric_input['references']):
            answer_texts = ref['answers']
            context = pipeline_input['context'][idx]
            ref['answers'] = {
                'answer_start': [context.find(text) for text in answer_texts],
                'text': answer_texts,
            }

    needs_no_answer_prob = metric == 'squad_v2'
    predictions = []
    for row in rows:
        prediction = {'prediction_text': row[predict_col], 'id': row[id_col]}
        if needs_no_answer_prob:
            prediction['no_answer_probability'] = 0.
        predictions.append(prediction)
    metric_input['predictions'] = predictions

    scorer = qa_evaluator.prepare_metric(metric)
    return qa_evaluator.compute_metric(scorer, metric_inputs=metric_input)
def evaluate_dataset_manual(id: str, subset: str, predict_col: str = 'predicted',
                            label_col: str = 'answer'):
    """Return the regex-match accuracy of stored predictions.

    A row scores 1 when ``regex_match_score`` succeeds for its prediction
    against at least one of its gold answers, and 0 otherwise (including
    when the row has no gold answers). The result is the mean row score
    over the ``train`` split of ``id``/``subset``.

    Args:
        id: Dataset name passed to ``load_dataset``.
        subset: Dataset configuration passed to ``load_dataset``.
        predict_col: Column holding the predicted answer text.
        label_col: Column holding the gold answers; each element is used as
            a regex pattern by ``regex_match_score``.

    Returns:
        Fraction of matched rows, as a float in [0, 1].

    Raises:
        ValueError: If the split contains no rows.
    """
    train_split = load_dataset(id, subset)['train']
    scores = [
        # default=False: a row with no gold answers scores 0 instead of
        # crashing max() on an empty sequence.
        max((regex_match_score(row[predict_col], answer)
             for answer in row[label_col]), default=False)
        for row in train_split
    ]
    if not scores:
        raise ValueError(
            f"Split 'train' of {id!r}/{subset!r} has no rows to score")
    return sum(scores) / len(scores)
def normalize_answer(s):
    """Canonicalize an answer string for lenient comparison.

    Steps, in order: Unicode NFD decomposition, lower-casing, deletion of
    every character in ``string.punctuation``, replacement of the whole
    words "a"/"an"/"the" with a space, and collapsing of whitespace runs
    into single spaces with leading/trailing whitespace dropped.
    """
    text = unicodedata.normalize("NFD", s).lower()
    # One-pass deletion of ASCII punctuation.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())
def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def regex_match_score(prediction, ground_truth):
    """Return True if ``ground_truth``, used as a regex, matches ``prediction``.

    The pattern is compiled case-insensitively with ``re.MULTILINE`` and is
    applied with ``Pattern.match``, so it must match at the very start of
    ``prediction`` (MULTILINE does not change that). ``ground_truth`` is
    always interpreted as a regular expression, so characters such as
    ``.`` or ``(`` keep their regex meaning.

    Returns False, rather than raising, when ``ground_truth`` is not a
    valid regular expression.
    """
    try:
        regex = re.compile(ground_truth,
                           flags=re.IGNORECASE | re.UNICODE | re.MULTILINE)
    except re.error:
        # An uncompilable gold pattern counts as a non-match so one bad
        # label does not abort a whole evaluation run.
        return False
    return regex.match(prediction) is not None