|
import numpy as np |
|
import torch |
|
|
|
from evaluate import load as load_metric |
|
|
|
from sklearn.metrics import accuracy_score, f1_score |
|
from tqdm.auto import tqdm |
|
|
|
# Upper bound on the generated sequence length, passed as ``max_length`` to
# ``model.generate`` in ``evaluate``.
MAX_TARGET_LENGTH = 128

# Text-generation metrics, loaded once when this module is imported and then
# reused by every ``compute_metrics`` call.
sacrebleu = load_metric('sacrebleu')
rouge = load_metric('rouge')
meteor = load_metric('meteor')
bertscore = load_metric('bertscore')

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
def flatten_list(l):
    """
    Utility function that concatenates the inner sequences of ``l`` into one list

    Params:
        l (iterable of iterables): nested collection, e.g. one list per batch

    Returns:
        list: every element of every inner sequence, in their original order
    """
    flat = []
    for chunk in l:
        flat.extend(chunk)
    return flat
|
|
|
def extract_feedback(predictions):
    """
    Utility function to extract the feedback text from the predictions of the model

    Predictions are presumably formatted like the references, i.e.
    ``"<label> Feedback: <text>"``. Parsing rules:

    * If the prediction contains a colon, the feedback is everything after the
      first colon.
    * Otherwise the leading label words are dropped: two words for
      "Partially correct" (case-insensitive), one word for anything else.
    * If nothing remains after the label, the whole prediction is returned.

    Params:
        predictions (list of str): complete model predictions

    Returns:
        feedback (list of str): extracted feedback, stripped of surrounding whitespace
    """
    feedback = []

    for pred in predictions:
        _, colon, after_colon = pred.partition(':')
        if colon:
            fb = after_colon
        else:
            # "Partially correct" is a two-word label; every other label is one
            # word. Splitting with maxsplit == n_label_words leaves the feedback
            # as item [n_label_words]. (The previous code used maxsplit=1 with
            # index [2], which can never exist.)
            n_label_words = 2 if pred.lower().startswith('partially correct') else 1
            words = pred.split(' ', n_label_words)
            # A prediction that is only the label has no feedback part: keep it whole.
            fb = words[n_label_words] if len(words) > n_label_words else pred
        feedback.append(fb.strip())

    return feedback
|
|
|
def extract_labels(predictions):
    """
    Utility function to map each model prediction to its grading label

    The label is read from the start of the prediction, ignoring case. A
    prediction that does not begin with a known label is reported as
    'Unknown label'.

    Params:
        predictions (list): complete model predictions (strings)

    Returns:
        labels (list): one of 'Correct', 'Partially correct', 'Incorrect' or
            'Unknown label' for each prediction, in the same order
    """
    # (lower-case prefix, canonical label). The prefixes start with different
    # letters, so at most one of them can match a given prediction.
    known_labels = (
        ('correct', 'Correct'),
        ('partially correct', 'Partially correct'),
        ('incorrect', 'Incorrect'),
    )

    def _label_of(text):
        lowered = text.lower()
        for prefix, canonical in known_labels:
            if lowered.startswith(prefix):
                return canonical
        return 'Unknown label'

    return [_label_of(pred) for pred in predictions]
|
|
|
def compute_metrics(predictions, labels):
    """
    Score decoded model predictions against the decoded gold references

    Feedback text is scored with SacreBLEU, ROUGE-2, METEOR and BERTScore (F1,
    rescaled with baseline). The predicted label is scored with accuracy,
    weighted F1 and macro F1.

    Params:
        predictions (list of str): complete decoded model predictions
        labels (list of str): decoded gold references; each must contain
            "Feedback:" separating the label (before) from the feedback
            (after), otherwise an IndexError is raised

    Returns:
        results (dict): metric name -> score, with keys 'sacrebleu', 'rouge',
            'meteor', 'bert_score', 'accuracy', 'f1_weighted' and 'f1_macro'
    """
    # Each prediction yields two parts: its free-text feedback and its label.
    predicted_feedback = extract_feedback(predictions)
    predicted_labels = extract_labels(predictions)

    # Split each reference once at the first "Feedback:" marker. Indexing
    # (rather than tuple unpacking) keeps an IndexError for a reference that
    # lacks the marker.
    reference_feedback, reference_labels = [], []
    for gold in labels:
        pieces = gold.split('Feedback:', 1)
        reference_feedback.append(pieces[1].strip())
        reference_labels.append(pieces[0].strip())

    # Text-generation metrics on the feedback. SacreBLEU takes a list of
    # reference lists, here with a single reference per prediction.
    bleu_value = sacrebleu.compute(
        predictions=predicted_feedback,
        references=[[ref] for ref in reference_feedback],
    )['score']
    rouge2_value = rouge.compute(
        predictions=predicted_feedback,
        references=reference_feedback,
    )['rouge2']
    meteor_value = meteor.compute(
        predictions=predicted_feedback,
        references=reference_feedback,
    )['meteor']
    bert_output = bertscore.compute(
        predictions=predicted_feedback,
        references=reference_feedback,
        lang='en',
        rescale_with_baseline=True,
    )

    # Classification metrics on the label. Macro F1 is restricted to the three
    # expected classes, so 'Unknown label' predictions are not averaged in as
    # a class of their own.
    gold_labels = np.array(reference_labels)
    class_order = ['Incorrect', 'Partially correct', 'Correct']

    return {
        'sacrebleu': bleu_value,
        'rouge': rouge2_value,
        'meteor': meteor_value,
        'bert_score': np.array(bert_output['f1']).mean().item(),
        'accuracy': accuracy_score(gold_labels, predicted_labels),
        'f1_weighted': f1_score(gold_labels, predicted_labels, average='weighted'),
        'f1_macro': f1_score(
            gold_labels,
            predicted_labels,
            average='macro',
            labels=class_order,
        ),
    }
|
|
|
def evaluate(model, tokenizer, dataloader):
    """
    Evaluate model on the given dataset

    Generates predictions for every batch of ``dataloader``, decodes the
    predictions and the gold labels to text, and scores them with
    ``compute_metrics``. The model is switched to eval mode and left there.

    Params:
        model (PreTrainedModel): seq2seq model
        tokenizer (PreTrainedTokenizer): tokenizer from HuggingFace
        dataloader (torch Dataloader): dataloader of the dataset to be used for
            evaluation; each batch is a dict of tensors that must include
            'input_ids', 'attention_mask' and 'labels'

    Returns:
        results (dict): dictionary with the computed evaluation metrics
        predictions (list): list of the decoded predictions of the model
    """
    # Label padding id ignored by the loss (e.g. the default label_pad_token_id
    # of HuggingFace's DataCollatorForSeq2Seq). It is not a real token id, so
    # batch_decode cannot decode it; such positions are mapped to the pad token
    # first. NOTE(review): assumes the dataloader pads labels this way; the
    # replacement is a no-op otherwise.
    ignore_index = -100
    pad_token_id = tokenizer.pad_token_id

    decoded_preds, decoded_labels = [], []

    model.eval()

    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}

            generated_tokens = model.generate(
                batch['input_ids'],
                attention_mask=batch['attention_mask'],
                max_length=MAX_TARGET_LENGTH,
            )

        labels_batch = batch['labels']
        if pad_token_id is not None:
            labels_batch = labels_batch.masked_fill(labels_batch == ignore_index, pad_token_id)

        # The pad token is a special token, so skip_special_tokens drops it from the text.
        decoded_preds.append(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
        decoded_labels.append(tokenizer.batch_decode(labels_batch, skip_special_tokens=True))

    predictions = flatten_list(decoded_preds)
    labels = flatten_list(decoded_labels)

    results = compute_metrics(predictions, labels)

    return results, predictions