add mmlu and cmmlu
tasks.py
CHANGED
@@ -3,19 +3,14 @@ from datasets import load_dataset, Dataset
 from functools import cached_property
 from tqdm.auto import tqdm
 from typing import Any, Optional, Protocol, Iterable, Callable
+import logging
+import pandas as pd
+from functools import partial
 
-from .utils import (
-    NUMERIC_IN_ZH,
-    extract_choice_ans,
-    extract_numeric,
-    get_answer,
-    is_equiv,
-)
+from .utils import *
 
 from evaluate import load
 
-TextGenerationPipeline = Callable[[Iterable[str]], list[str]]
-
 
 def fake_pipeline(prompts: Iterable[str]) -> list[str]:
     return [prompt for prompt in tqdm(prompts)]
@@ -30,14 +25,25 @@ class Task:
     input_column: str = "question"
     label_column: str = "answer"
    prompt: Optional[Callable | str] = None
+    few_shot: int = 0
+    few_shot_from: Optional[str] = None
+    # results: dict[str, Any] = field(default_factory=dict)
 
-    def __post_init__(self):
-        self.name = (
-            self.dataset_name
-            if isinstance(self.dataset_name, str)
-            else self.dataset_name
-        )
+    def __post_init__(self):
+        names = (
+            [self.dataset_name]
+            if isinstance(self.dataset_name, str)
+            else list(self.dataset_name)
+        )
+        names[0] = names[0].split("/")[-1]
+
+        self.name = "-".join(names) + f"-{self.split}"
+        if isinstance(self.prompt, str):
+            self.prompt = lambda example: {
+                self.input_column: self.prompt.format(
+                    input_column=example[self.input_column]
+                )
+            }
 
     @cached_property
     def samples(self):
@@ -49,20 +55,38 @@ class Task:
             *self.dataset_name
             if isinstance(self.dataset_name, tuple)
             else self.dataset_name,
-            split=self.split,
+            # split=self.split,
         )
+        test_ds = ds[self.split]
         if self.prompt is not None:
-            ds = ds.map(
-                lambda example: {
-                    self.input_column: self.prompt.format(
-                        input_column=example[self.input_column]
-                    )
-                }
-                if isinstance(self.prompt, str)
-                else self.prompt(example),
-            )
-
-        return ds
+            test_ds = test_ds.map(self.prompt)
+
+        if self.few_shot:
+            if self.few_shot_from is None:
+                for name in ["train", "validation", "val", "dev"]:
+                    if name in ds:
+                        self.few_shot_from = name
+                        break
+
+            shots = ds[self.few_shot_from].select(range(self.few_shot))
+            if self.prompt is not None:
+                shots = shots.map(self.prompt)
+
+            shots = shots.map(
+                lambda example: {
+                    self.input_column: example[self.input_column]
+                    + example[self.label_column],
+                }
+            )[self.input_column]
+            few_shot_prompts = "\n".join(shots)
+
+            test_ds = test_ds.map(
+                lambda example: {
+                    self.input_column: few_shot_prompts + example[self.input_column],
+                }
+            )
+
+        return test_ds
 
     @cached_property
     def metric(self):
@@ -73,14 +97,44 @@ class Task:
         )
         return metric
 
-    def run(
-        self,
-        pipeline: TextGenerationPipeline,
-        log: bool = True,
-    ):
+    def run(
+        self,
+        pipeline,
+    ):
+        if (outputs := pipeline(self.samples)) is None:
+            logging.warning("pipeline returns None")
+            return
+        self.outputs = outputs
+        try:
+            result = self.metric._compute(
+                responses=outputs, references=self.dataset[self.label_column]
+            )
+        except Exception as e:
+            result = self.metric.compute(
+                responses=outputs, references=self.dataset[self.label_column]
+            )
+        # if log:
+        #     name = name or pipeline.__name__
+        #     self.results[name] = result
+
+        return result
+
+
+def multichoice(responses: Any, references: list[str]):
+    if isinstance(responses[0], str):
+        responses = [extract_choice(response) for response in responses]
+    else:
+        responses = decode_choice(responses)
+
+    return [
+        int(response == reference) for reference, response in zip(references, responses)
+    ]
 
 
 class Metrics:
+    cmmlu = multichoice
+    mmlu = multichoice
+
     def gsm8k(responses: list[str], answers: list[str | int]):
         scores = []
         for response, answer in zip(responses, answers):
@@ -112,26 +166,287 @@ class Metrics:
             scores.append(1.0 * (pred == gold))
         return scores
 
-    def gsm8k_zh(responses: list[str], answers: list[str]):
-        scores = []
-        for response, answer in zip(responses, answers):
-            pred = extract_numeric(response, pattern=NUMERIC_IN_ZH)
-            gold = extract_numeric(answer)
-            scores.append(1.0 * (pred == gold))
-        return scores
 
-    [… a second removed metric whose body is not legible in the capture …]
-        return scores
 
-    [… further removed lines, not legible in the capture …]
+class CMMLU:
+    def prompt_cmmlu(example, chat=False):
+        prefix = "以下是一道多项选择题,请从A、B、C和D中选择最合适的答案作为这个问题的答案。\n\n" if chat else "问题:"
+        prompt = prefix + example["Question"]
+        for choice in list("ABCD"):
+            prompt += f"\n{choice}. {example[choice]}"
+
+        prompt += "\n答案:"
+        return {"prompt": prompt}
+
+    subcategories = {
+        "agronomy": ["other"],
+        "anatomy": ["biology"],
+        "ancient_chinese": ["linguistics", "china specific"],
+        "arts": ["arts"],
+        "astronomy": ["physics"],
+        "business_ethics": ["business"],
+        "chinese_civil_service_exam": ["politics", "china specific"],
+        "chinese_driving_rule": ["other", "china specific"],
+        "chinese_food_culture": ["culture", "china specific"],
+        "chinese_foreign_policy": ["politics", "china specific"],
+        "chinese_history": ["history", "china specific"],
+        "chinese_literature": ["literature", "china specific"],
+        "chinese_teacher_qualification": ["education", "china specific"],
+        "college_actuarial_science": ["math"],
+        "college_education": ["education"],
+        "college_engineering_hydrology": ["engineering"],
+        "college_law": ["law"],
+        "college_mathematics": ["math"],
+        "college_medical_statistics": ["statistics"],
+        "clinical_knowledge": ["other"],
+        "college_medicine": ["other"],
+        "computer_science": ["computer science"],
+        "computer_security": ["other"],
+        "conceptual_physics": ["physics"],
+        "construction_project_management": ["other", "china specific"],
+        "economics": ["economics"],
+        "education": ["education"],
+        "elementary_chinese": ["linguistics", "china specific"],
+        "elementary_commonsense": ["other", "china specific"],
+        "elementary_information_and_technology": ["other"],
+        "electrical_engineering": ["engineering"],
+        "elementary_mathematics": ["math"],
+        "ethnology": ["culture", "china specific"],
+        "food_science": ["other"],
+        "genetics": ["biology"],
+        "global_facts": ["global"],
+        "high_school_biology": ["biology"],
+        "high_school_chemistry": ["chemistry"],
+        "high_school_geography": ["geography"],
+        "high_school_mathematics": ["math"],
+        "high_school_physics": ["physics"],
+        "high_school_politics": ["politics", "china specific"],
+        "human_sexuality": ["other"],
+        "international_law": ["law"],
+        "journalism": ["sociology"],
+        "jurisprudence": ["law"],
+        "legal_and_moral_basis": ["other"],
+        "logical": ["philosophy"],
+        "machine_learning": ["computer science"],
+        "management": ["business"],
+        "marketing": ["business"],
+        "marxist_theory": ["philosophy"],
+        "modern_chinese": ["linguistics", "china specific"],
+        "nutrition": ["other"],
+        "philosophy": ["philosophy"],
+        "professional_accounting": ["business"],
+        "professional_law": ["law"],
+        "professional_medicine": ["other"],
+        "professional_psychology": ["psychology"],
+        "public_relations": ["politics"],
+        "security_study": ["politics"],
+        "sociology": ["culture"],
+        "sports_science": ["other"],
+        "traditional_chinese_medicine": ["other", "china specific"],
+        "virology": ["biology"],
+        "world_history": ["history"],
+        "world_religions": ["global"],
+    }
+
+    categories = {
+        "STEM": [
+            "physics",
+            "chemistry",
+            "biology",
+            "computer science",
+            "math",
+            "engineering",
+            "statistics",
+        ],
+        "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
+        "Social Science": [
+            "linguistics",
+            "business",
+            "politics",
+            "culture",
+            "economics",
+            "geography",
+            "psychology",
+            "education",
+            "sociology",
+        ],
+        "Other": ["other"],
+        "China specific": ["china specific"],
+        "Test": ["computer science"],
+    }
+
+    finer_categories = (
+        pd.Series(subcategories)  # noqa # type: ignore
+        .explode()
+        .reset_index()
+        .set_index(0)
+        .groupby(0)
+        .agg(list)["index"]
+        .to_dict()
+    )
+
+    @classmethod
+    def suite(cls, chat=False):
+        suite = {}
+        for k, v in cls.categories.items():
+            for subject in v:
+                suite[k] = [
+                    Task(
+                        ("haonan-li/cmmlu", subcategories),
+                        metric_name=("sustech/tlem", "cmmlu"),
+                        input_column="prompt",
+                        label_column="Answer",
+                        prompt=partial(cls.prompt_cmmlu, chat=chat),
+                    )
+                    for subcategories in cls.finer_categories[subject]
+                ]
+        return suite
+
+
+class MMLU:
+    input_column = "prompt"
+    label_column = "target"
+
+    @classmethod
+    def prompt_mmlu(cls, example, chat=False):
+        prefix = (
+            "The following is a multiple-choice question. Please choose the most suitable one among A, B, C and D as the answer to this question.\n\n"
+            if chat
+            else "Question: "
+        )
+        prompt = prefix + example["input"]
+        for choice in list("ABCD"):
+            prompt += f"\n{choice}. {example[choice]}"
+
+        prompt += "\nAnswer:"
+        return {"prompt": prompt}
+
+    subcategories = {
+        "abstract_algebra": ["math"],
+        "anatomy": ["health"],
+        "astronomy": ["physics"],
+        "business_ethics": ["business"],
+        "clinical_knowledge": ["health"],
+        "college_biology": ["biology"],
+        "college_chemistry": ["chemistry"],
+        "college_computer_science": ["computer science"],
+        "college_mathematics": ["math"],
+        "college_medicine": ["health"],
+        "college_physics": ["physics"],
+        "computer_security": ["computer science"],
+        "conceptual_physics": ["physics"],
+        "econometrics": ["economics"],
+        "electrical_engineering": ["engineering"],
+        "elementary_mathematics": ["math"],
+        "formal_logic": ["philosophy"],
+        "global_facts": ["other"],
+        "high_school_biology": ["biology"],
+        "high_school_chemistry": ["chemistry"],
+        "high_school_computer_science": ["computer science"],
+        "high_school_european_history": ["history"],
+        "high_school_geography": ["geography"],
+        "high_school_government_and_politics": ["politics"],
+        "high_school_macroeconomics": ["economics"],
+        "high_school_mathematics": ["math"],
+        "high_school_microeconomics": ["economics"],
+        "high_school_physics": ["physics"],
+        "high_school_psychology": ["psychology"],
+        "high_school_statistics": ["math"],
+        "high_school_us_history": ["history"],
+        "high_school_world_history": ["history"],
+        "human_aging": ["health"],
+        "human_sexuality": ["culture"],
+        "international_law": ["law"],
+        "jurisprudence": ["law"],
+        "logical_fallacies": ["philosophy"],
+        "machine_learning": ["computer science"],
+        "management": ["business"],
+        "marketing": ["business"],
+        "medical_genetics": ["health"],
+        "miscellaneous": ["other"],
+        "moral_disputes": ["philosophy"],
+        "moral_scenarios": ["philosophy"],
+        "nutrition": ["health"],
+        "philosophy": ["philosophy"],
+        "prehistory": ["history"],
+        "professional_accounting": ["other"],
+        "professional_law": ["law"],
+        "professional_medicine": ["health"],
+        "professional_psychology": ["psychology"],
+        "public_relations": ["politics"],
+        "security_studies": ["politics"],
+        "sociology": ["culture"],
+        "us_foreign_policy": ["politics"],
+        "virology": ["health"],
+        "world_religions": ["philosophy"],
+    }
+
+    categories = {
+        "Math": [
+            "math",
+        ],
+        "STEM": [
+            "physics",
+            "chemistry",
+            "biology",
+            "computer science",
+            "math",
+            "engineering",
+        ],
+        "humanities": ["history", "philosophy", "law"],
+        "social sciences": [
+            "politics",
+            "culture",
+            "economics",
+            "geography",
+            "psychology",
+        ],
+        "Other": ["other", "business", "health"],
+        "All": [
+            "physics",
+            "chemistry",
+            "biology",
+            "computer science",
+            "math",
+            "engineering",
+            "history",
+            "philosophy",
+            "law",
+            "politics",
+            "culture",
+            "economics",
+            "geography",
+            "psychology",
+            "other",
+            "business",
+            "health",
+        ],
+        "Test": ["culture"],
+    }
+
+    @classmethod
+    def suite(cls, chat=False):
+        finer_categories = (
+            pd.Series(cls.subcategories)  # noqa # type: ignore
+            .explode()
+            .reset_index()
+            .set_index(0)
+            .groupby(0)
+            .agg(list)["index"]
+            .to_dict()
+        )
+        suite = {}
+        for k, v in cls.categories.items():
+            for subject in v:
+                suite[k] = [
+                    Task(
+                        ("lukaemon/mmlu", subcategories),
+                        metric_name=("sustech/tlem", "mmlu"),
+                        input_column=cls.input_column,
+                        label_column=cls.label_column,
+                        prompt=partial(cls.prompt_mmlu, chat=chat),
+                        few_shot=0 if chat else 5,
+                        few_shot_from="validation",
+                    )
+                    for subcategories in finer_categories[subject]
+                ]
+        return suite
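For orientation (not part of the commit): a minimal sketch of how the new few-shot plumbing in Task composes. The import path, reachable Hub datasets, and a default split of "test" are all assumptions here.

from tlem.tasks import MMLU, Task, fake_pipeline  # hypothetical import path

# One MMLU subject, with 5 solved validation examples prepended to every test prompt.
task = Task(
    ("lukaemon/mmlu", "astronomy"),
    metric_name=("sustech/tlem", "mmlu"),
    input_column="prompt",
    label_column="target",
    prompt=MMLU.prompt_mmlu,  # builds "Question: …\nA. …\n…\nAnswer:"
    few_shot=5,
    few_shot_from="validation",
)
print(task.name)                  # "mmlu-astronomy-test", per __post_init__
print(task.samples[0]["prompt"])  # few-shot block followed by the first test question
outputs = fake_pipeline(task.samples["prompt"])  # echo pipeline as a stand-in model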
tlem.py
CHANGED
@@ -11,7 +11,8 @@ from evaluate.evaluation_suite import EvaluationSuite
 import evaluate
 import numpy as np
 import datasets
-from .tasks import Task, Metrics
+import pandas as pd
+from .tasks import *
 from .utils import is_equiv
 
 # %%
@@ -24,56 +25,35 @@ from .utils import is_equiv
 
 # TODO: Add BibTeX citation
 _CITATION = """\
-@InProceedings{huggingface:module,
-title = {A great new module},
-authors={huggingface, Inc.},
-year={2020}
-}
 """
 
 # TODO: Add description of the module here
 _DESCRIPTION = """\
-A simple measurement that returns the number of elements in dataset.
 """
 
 
 # TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
-Calculates number of elements in dataset
-Args:
-    data: list of elements.
-Returns:
-    element_count: number of elements in dataset,
-Examples:
-    >>> measure = evaluate.load("lvwerra/element_count")
-    >>> measure.compute(["a", "b", "c")
-    {"element_count": 3}
 """
 
 # TODO: Define external resources urls if needed
 BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
 
 
-@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+# @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class ReasoningMetric(evaluate.Metric):
     """TODO: Short description of my evaluation module."""
 
     def _info(self):
+        # if self.config_name in ["cmmlu"]:
         features = datasets.Features(
            {
                 "responses": datasets.Value("string"),
+                # "responses": datasets.Sequence(datasets.Value("float")),
                 "references": datasets.Value("string"),
             }
         )
 
-        if self.config_name == "svamp":
-            features = datasets.Features(
-                {
-                    "responses": datasets.Value("string"),
-                    "references": datasets.Value("float"),
-                }
-            )
-
         # TODO: Specifies the evaluate.EvaluationModuleInfo object
         return evaluate.EvaluationModuleInfo(
             # This is the description that will appear on the modules page.
@@ -90,38 +70,59 @@ class ReasoningMetric(evaluate.Metric):
             reference_urls=["http://path.to.reference.url/new_module"],
         )
 
-    def _compute(self, responses, references, verbose=False):
-        results = {}
+    def _compute(self, responses, references):
         scores = getattr(Metrics, self.config_name)(responses, references)
-        acc = np.nanmean(scores)
-        results = {
-            "accuracy": acc,
-            "scores": scores,
-        }
-
-        if verbose:
-            results["references"] = references
-            results["answers"] = responses
-            # results["scores"] = scores
-
+        results = {"Accuracy": np.nanmean(scores)}
+        logging.info(results)
         return results
 
 
 class Suite(EvaluationSuite):
     def run(
         self,
+        model_or_pipeline: Any,
+        name="tlem",
     ) -> dict[str, float]:
         self.assert_suite_nonempty()
 
-        [… the old run body is not legible in the capture …]
-
-    def __init__(self, name):
+        def run_tasks(tasks):
+            for task in tqdm(tasks):
+                if task.name not in self.cached_result:
+                    self.cached_result[task.name] = task.run(model_or_pipeline)
+            results = [self.cached_result[task.name] for task in tasks]
+            return pd.DataFrame(results).mean().to_dict()
+
+        if isinstance(self.suite, dict):
+            for category, tasks in tqdm(self.suite.items()):
+                logging.warning(f"Combined results: {category}:{run_tasks(tasks)}")
+        else:
+            logging.warning(f"Combined results: {run_tasks(self.suite)}")
+
+        return self.cached_result
+
+    def add(self, name):
+        chat = False
+        match name:
+            case _ if "chat" in name:
+                chat = True
+        match name:
+            case _ if name.startswith("mmlu"):
+                suite = MMLU.suite(chat=chat)
+            case _ if name.startswith("cmmlu"):
+                suite = CMMLU.suite(chat=chat)
+        match name:
+            case _ if "test" in name:
+                suite = suite["Test"]
+
+        self.suite = suite
+
+    def __init__(self, name="tlem"):
         super().__init__(name)
+        self.cached_result = {}
+
+        match self.name:
+            case "cmmlu":
+                pass
 
         self.suite = [
             Task(
@@ -136,4 +137,3 @@ class Suite(EvaluationSuite):
 
 
 # %%
-
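For orientation (not part of the commit): the driver flow the new Suite is aiming at. Note that Task.run reads self.dataset, which samples never sets, so treat this as the intended call pattern rather than a verified transcript; the import paths are assumptions.

from tlem.tlem import Suite  # hypothetical import path
from tlem.tasks import fake_pipeline

suite = Suite()               # builds the default suite with an empty result cache
suite.add("cmmlu-chat-test")  # "chat" picks the chat prompt prefix; "test" keeps only the small Test category
results = suite.run(fake_pipeline)  # {task.name: result, ...}; per-category means are logged as warnings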
utils.py
CHANGED
@@ -1,5 +1,7 @@
 import logging
 import re
+import numpy as np
+from typing import Any
 
 NUMERIC_IN_EN = r"(?:[\s=+-/<>($:\.\*\\])(?=\S)((?:0|(?:\d{1,3}(?:,\d{3})+(?=\D|$))|(?:\d+))(?:\.\d+)?%?)(?:(?![^\s=+-/>)$:\.\*\\])|(?=, ))"
 NUMERIC_IN_ZH = (
@@ -7,17 +9,43 @@ NUMERIC_IN_ZH = (
 )
 
 
-def extract_choice_ans(gen):
-    [… body not legible in the capture …]
-
-def standardize(ans):
-    return ans if len(ans) == 1 else ans[1]
-
-[… one more removed line, not legible in the capture …]
+def extract_choice(gen):
+    # answer is A | choice is A | choose A
+    res = re.search(
+        r"(?:(?:[Cc]hoose)|(?:(?:[Aa]nswer|[Cc]hoice)(?![^ABCD]{0,20}?(?:n't|not))[^ABCD]{0,10}?\b(?:|is|:|be))\b)[^ABCD]{0,20}?\b(A|B|C|D)\b",
+        gen,
+    )
+
+    # A is correct | A is right
+    if res is None:
+        res = re.search(
+            r"\b(A|B|C|D)\b(?![^ABCD]{0,8}?(?:n't|not)[^ABCD]{0,5}?(?:correct|right))[^ABCD]{0,10}?\b(?:correct|right)\b",
+            gen,
+        )
+
+    # straight answer: A
+    if res is None:
+        res = re.search(r"^(A|B|C|D)(?:\.|,|:|$)", gen)
+
+    # simply extract the first appeared letter
+    if res is None:
+        res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)
+
+    if res is None:
+        res = "A"
+
+    if isinstance(res, str):
+        return res
+
+    return res.group(1)
 
 
+def decode_choice(responses: list[Any]):
+    num_choices = responses[0].shape[0]
+    choices = np.argmax(np.asarray(responses), axis=1)
+    responses = np.array(list("ABCDEFGHIJKL"[:num_choices]))[choices]
+    # return (responses == np.array(references)).mean()
+    return responses
 
 
 def extract_numeric(string, pattern=NUMERIC_IN_EN) -> str:
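The two helpers above are the two halves of the new multichoice metric in tasks.py: extract_choice funnels free-text answers through progressively looser regexes (defaulting to "A" when nothing matches), while decode_choice argmaxes per-choice score vectors into letters. A quick sanity check (a sketch, assuming the same hypothetical import path as above):

import numpy as np
from tlem.utils import extract_choice, decode_choice  # hypothetical import path

assert extract_choice("The answer is B.") == "B"  # first pattern: "answer … B"
assert extract_choice("C is correct.") == "C"     # second pattern: "C … correct"
assert extract_choice("no idea") == "A"           # last-resort default
assert decode_choice([np.array([0.1, 0.7, 0.1, 0.1])])[0] == "B"  # argmax mapped to a letter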