Update for binary classification
Browse files
- bias_auc.py +6 -6
bias_auc.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import evaluate
|
2 |
import datasets
|
3 |
from datasets.features import Sequence, Value, ClassLabel
|
@@ -30,7 +32,6 @@ Args:
|
|
30 |
label list[int]: list containing label index for each item
|
31 |
output list[list[float]]: list of model output values for each item
|
32 |
subgroup list[str] (optional): list of subgroups that appear in target to compute metric over
|
33 |
-
|
34 |
Returns (for each subgroup in target):
|
35 |
'Subgroup' : Subgroup AUC score,
|
36 |
'BPSN' : BPSN (Background Positive, Subgroup Negative) AUC,
|
@@ -49,7 +50,6 @@ Example:
|
|
49 |
... [0.4341845214366913, 0.5658154487609863],
|
50 |
... [0.400595098733902, 0.5994048714637756],
|
51 |
... [0.3840397894382477, 0.6159601807594299]]
|
52 |
-
|
53 |
>>> metric = load('Intel/bias_auc')
|
54 |
>>> metric.add_batch(target=target,
|
55 |
label=label,
|
@@ -67,7 +67,7 @@ class BiasAUC(evaluate.Metric):
|
|
67 |
features=datasets.Features(
|
68 |
{
|
69 |
'target': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
70 |
-
'label':
|
71 |
'output': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
|
72 |
}
|
73 |
),
|
@@ -78,7 +78,7 @@ class BiasAUC(evaluate.Metric):
|
|
78 |
"""Returns label and output score from `targets` and `labels`
|
79 |
if `subgroup` is in list of targeted groups found in `targets`
|
80 |
"""
|
81 |
-
target_class = target_class if target_class is not None else
|
82 |
for target, label, result in zip(targets, labels, outputs):
|
83 |
if subgroup in target:
|
84 |
yield label, result[target_class]
|
@@ -89,7 +89,7 @@ class BiasAUC(evaluate.Metric):
|
|
89 |
label is not the same as `target_class`; or (2) `subgroup` is not in list of
|
90 |
targeted groups found in `targets` and label is the same as `target_class`
|
91 |
"""
|
92 |
-
target_class = target_class if target_class is not None else
|
93 |
for target, label, result in zip(targets, labels, outputs):
|
94 |
if not target:
|
95 |
continue
|
@@ -107,7 +107,7 @@ class BiasAUC(evaluate.Metric):
|
|
107 |
targeted groups found in `targets` and label is not the same as `target_class`
|
108 |
"""
|
109 |
# get the index from class
|
110 |
-
target_class = target_class if target_class is not None else
|
111 |
for target, label, result in zip(targets, labels, outputs):
|
112 |
if not target:
|
113 |
continue
|
|
|
1 |
+
%%writefile test_metric/test_metric.py
|
2 |
+
|
3 |
import evaluate
|
4 |
import datasets
|
5 |
from datasets.features import Sequence, Value, ClassLabel
|
|
|
32 |
label list[int]: list containing label index for each item
|
33 |
output list[list[float]]: list of model output values for each item
|
34 |
subgroup list[str] (optional): list of subgroups that appear in target to compute metric over
|
|
|
35 |
Returns (for each subgroup in target):
|
36 |
'Subgroup' : Subgroup AUC score,
|
37 |
'BPSN' : BPSN (Background Positive, Subgroup Negative) AUC,
|
|
|
50 |
... [0.4341845214366913, 0.5658154487609863],
|
51 |
... [0.400595098733902, 0.5994048714637756],
|
52 |
... [0.3840397894382477, 0.6159601807594299]]
|
|
|
53 |
>>> metric = load('Intel/bias_auc')
|
54 |
>>> metric.add_batch(target=target,
|
55 |
label=label,
|
|
|
67 |
features=datasets.Features(
|
68 |
{
|
69 |
'target': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
|
70 |
+
'label': Value(dtype='int64', id=None),
|
71 |
'output': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
|
72 |
}
|
73 |
),
|
|
|
78 |
"""Returns label and output score from `targets` and `labels`
|
79 |
if `subgroup` is in list of targeted groups found in `targets`
|
80 |
"""
|
81 |
+
target_class = target_class if target_class is not None else 0
|
82 |
for target, label, result in zip(targets, labels, outputs):
|
83 |
if subgroup in target:
|
84 |
yield label, result[target_class]
|
|
|
89 |
label is not the same as `target_class`; or (2) `subgroup` is not in list of
|
90 |
targeted groups found in `targets` and label is the same as `target_class`
|
91 |
"""
|
92 |
+
target_class = target_class if target_class is not None else 1
|
93 |
for target, label, result in zip(targets, labels, outputs):
|
94 |
if not target:
|
95 |
continue
|
|
|
107 |
targeted groups found in `targets` and label is not the same as `target_class`
|
108 |
"""
|
109 |
# get the index from class
|
110 |
+
target_class = target_class if target_class is not None else 1
|
111 |
for target, label, result in zip(targets, labels, outputs):
|
112 |
if not target:
|
113 |
continue
|