Spaces:
Runtime error
Runtime error
File size: 1,553 Bytes
daeb223 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import datasets
from transformers.trainer_utils import EvalPrediction
accuracy = datasets.load_metric("accuracy").compute
precision = datasets.load_metric("precision").compute
recall = datasets.load_metric("recall").compute
f1 = datasets.load_metric("f1").compute
squad_v2 = datasets.load_metric("squad_v2").compute
def compute_classification_metric(p: EvalPrediction):
"""
'predictions': Value(dtype='int32', id=None)
'references': Value(dtype='int32', id=None)
"""
predictions = p.predictions.argmax(axis=1)
references = p.label_ids
metric = accuracy(predictions=predictions, references=references)
metric.update(precision(predictions=predictions, references=references))
metric.update(recall(predictions=predictions, references=references))
metric.update(f1(predictions=predictions, references=references))
return metric
def compute_squad_v2(p: EvalPrediction):
"""
'predictions': {
'id': Value(dtype='string', id=None),
'prediction_text': Value(dtype='string', id=None),
'no_answer_probability': Value(dtype='float32', id=None)
}
'references': {
'id': Value(dtype='string', id=None),
'answers': Sequence(
feature={
'text': Value(dtype='string', id=None),
'answer_start': Value(dtype='int32', id=None)
},
length=-1, id=None
)
}
"""
predictions = p.predictions
references = p.label_ids
return squad_v2(predictions=predictions, references=references) |