Ikala-allen committed on
Commit
a84678a
1 Parent(s): b519cf9

add evaluation

Browse files
Files changed (4)
  1. .gitignore +171 -0
  2. app.py +10 -0
  3. custom_metric/custom_metric.py +203 -0
  4. custom_metric/metric.yml +10 -0
.gitignore ADDED
@@ -0,0 +1,171 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.conll
+ *.pt
+ *.onnx
+ # C extensions
+ *.so
+ *.csv
+ *.json
+ *.joblib
+ *.ipynb
+ *.pkl
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ .python-version
+ env/
+ venv
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+ split.py
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # Logs
+ *.log.*
+ */logs
+ */var/run
+ */*/*/*/run/*
+ */*/*/*/logs/*
+
+ # OS generated files
+ .DS_Store*
+ ehthumbs.db
+ Icon?
+ Thumbs.db
+
+ # Editor Files
+ *~
+ *.swp
+ cli/meta
+ # IDE Files
+ .vscode/
+
+ # model file
+ *.pkl
+ catboost_info/
app.py ADDED
@@ -0,0 +1,10 @@
+ import evaluate
+ from evaluate.utils import launch_gradio_widget
+
+ # Define the path to your custom metric directory
+ metric_path = "./custom_metric"
+
+
+ module = evaluate.load(metric_path)
+ launch_gradio_widget(module)
+
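Note: app.py only wraps the metric in evaluate's Gradio widget. As a minimal local sketch (assuming evaluate is installed, the script runs from the repository root, and the triplet values below are invented for illustration), the same module can be loaded and scored directly, using the dict-of-lists layout declared by the datasets.Sequence features in custom_metric/custom_metric.py:

import evaluate

module = evaluate.load("./custom_metric")  # directory containing custom_metric.py

# one dict of lists per example; index i across the lists describes one triplet
example = {
    "head": ["SABONTAIWAN"],
    "head_type": ["brand"],
    "type": ["sell"],
    "tail": ["大馬士革玫瑰有機光燦系列"],
    "tail_type": ["product"],
}

scores = module.compute(predictions=[example], references=[example])
print(scores["ALL"])  # tp/fp/fn counts plus micro and macro precision, recall, F1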
custom_metric/custom_metric.py ADDED
@@ -0,0 +1,203 @@
+ import evaluate
+ import datasets
+ import numpy as np
+
+ _CITATION = """\
+ @InProceedings{huggingface:module,
+ title = {A great new module},
+ authors={huggingface, Inc.},
+ year={2020}
+ }
+ """
+
+ _DESCRIPTION = """\
+ Precision, recall and F1 for relation-extraction triplets, reported per relation type
+ together with micro- and macro-averaged scores over all types.
+ """
+
+
+ _KWARGS_DESCRIPTION = """
+ Scores predicted relation triplets against reference triplets.
+ Args:
+     predictions: list with one entry per example. Each entry is a dict of lists with the
+         keys 'head', 'head_type', 'type', 'tail' and 'tail_type'; index i across the lists
+         describes one predicted triplet.
+     references: list of reference triplets for each example, in the same format as
+         predictions.
+     mode: 'strict' matches head, head_type, tail and tail_type; 'boundaries' matches only
+         the head and tail spans. Defaults to 'strict'.
+     relation_types: optional list of relation types to score; inferred from the references
+         when omitted.
+ Returns:
+     A dict with one entry per relation type plus 'ALL', each holding 'tp', 'fp', 'fn',
+     precision 'p', recall 'r' and 'f1'; 'ALL' additionally holds 'Macro_p', 'Macro_r'
+     and 'Macro_f1'.
+ Examples:
+     >>> metric = evaluate.load("path/to/custom_metric")
+     >>> results = metric.compute(predictions=preds, references=refs)
+     >>> print(results["ALL"]["f1"])
+ """
+
+ # TODO: Define external resources urls if needed
+ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
+
+
+ def convert_format(data: list):
+     """
+     Convert each example from a dict of lists into a list of triplet dicts.
+
+     Args:
+         data (list): list of dictionaries with different entity elements,
+         e.g.
+         [
+             {'head': ['phipigments', 'tinadaviespigments'...],
+              'head_type': ['product', 'brand'...],
+              'type': ['sell', 'sell'...],
+              'tail': ['國際認證之色乳', '國際認證之色乳'...],
+              'tail_type': ['product', 'product'...]},
+
+             {'head': ['SABONTAIWAN', 'SNTAIWAN'...],
+              'head_type': ['brand', 'brand'...],
+              'type': ['sell', 'sell'...],
+              'tail': ['大馬士革玫瑰有機光燦系列', '大馬士革玫瑰有機光燦系列'...],
+              'tail_type': ['product', 'product'...]}
+             ...
+         ]
+     """
+     predictions = []
+     for item in data:
+         prediction_group = []
+         for i in range(len(item['head'])):
+             prediction = {
+                 'head': item['head'][i],
+                 'head_type': item['head_type'][i],
+                 'type': item['type'][i],
+                 'tail': item['tail'][i],
+                 'tail_type': item['tail_type'][i],
+             }
+             prediction_group.append(prediction)
+         predictions.append(prediction_group)
+     return predictions
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class relation_extraction(evaluate.Metric):
+     """Precision/recall/F1 evaluation module for relation-extraction triplets."""
+
+     def _info(self):
+         return evaluate.MetricInfo(
+             # This is the description that will appear on the modules page.
+             module_type="metric",
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             # This defines the format of each prediction and reference
+             features=datasets.Features({
+                 'predictions': datasets.Sequence({
+                     "head": datasets.Value("string"),
+                     "head_type": datasets.Value("string"),
+                     "type": datasets.Value("string"),
+                     "tail": datasets.Value("string"),
+                     "tail_type": datasets.Value("string"),
+                 }),
+                 'references': datasets.Sequence({
+                     "head": datasets.Value("string"),
+                     "head_type": datasets.Value("string"),
+                     "type": datasets.Value("string"),
+                     "tail": datasets.Value("string"),
+                     "tail_type": datasets.Value("string"),
+                 }),
+             }),
+             # Homepage of the module for documentation
+             homepage="http://module.homepage",
+             # Additional links to the codebase or references
+             codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
+             reference_urls=["http://path.to.reference.url/new_module"]
+         )
+
+     def _download_and_prepare(self, dl_manager):
+         """Optional: download external resources useful to compute the scores"""
+         # TODO: Download external resources if needed
+         pass
+
+     def _compute(self, predictions, references, mode="strict", relation_types=None):
+         """Returns the scores"""
+         predictions = convert_format(predictions)
+         references = convert_format(references)
+         assert mode in ["strict", "boundaries"]
+
+         # Construct relation_types from the ground truth if not given.
+         # A None default avoids a mutable default list that would leak types across calls.
+         if not relation_types:
+             relation_types = []
+             for triplets in references:
+                 for triplet in triplets:
+                     relation = triplet["type"]
+                     if relation not in relation_types:
+                         relation_types.append(relation)
+
+         scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
+
+         # Count GT relations and predicted relations
+         n_sents = len(references)
+         n_rels = sum(len(sent) for sent in references)
+         n_found = sum(len(sent) for sent in predictions)
+
+         # Count TP, FP and FN per type
+         for pred_sent, gt_sent in zip(predictions, references):
+             for rel_type in relation_types:
+                 # strict mode takes argument types into account
+                 if mode == "strict":
+                     pred_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
+                                  for rel in pred_sent if rel["type"] == rel_type}
+                     gt_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"])
+                                for rel in gt_sent if rel["type"] == rel_type}
+
+                 # boundaries mode only takes argument spans into account
+                 elif mode == "boundaries":
+                     pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
+                     gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}
+
+                 scores[rel_type]["tp"] += len(pred_rels & gt_rels)
+                 scores[rel_type]["fp"] += len(pred_rels - gt_rels)
+                 scores[rel_type]["fn"] += len(gt_rels - pred_rels)
+
+         # Compute per relation type Precision / Recall / F1
+         for rel_type in scores.keys():
+             if scores[rel_type]["tp"]:
+                 scores[rel_type]["p"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
+                 scores[rel_type]["r"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
+             else:
+                 scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0
+
+             if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
+                 scores[rel_type]["f1"] = 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (
+                     scores[rel_type]["p"] + scores[rel_type]["r"])
+             else:
+                 scores[rel_type]["f1"] = 0
+
+         # Compute micro F1 scores
+         tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
+         fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
+         fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)
+
+         if tp:
+             precision = 100 * tp / (tp + fp)
+             recall = 100 * tp / (tp + fn)
+             f1 = 2 * precision * recall / (precision + recall)
+         else:
+             precision, recall, f1 = 0, 0, 0
+
+         scores["ALL"]["p"] = precision
+         scores["ALL"]["r"] = recall
+         scores["ALL"]["f1"] = f1
+         scores["ALL"]["tp"] = tp
+         scores["ALL"]["fp"] = fp
+         scores["ALL"]["fn"] = fn
+
+         # Compute macro F1 scores
+         scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in relation_types])
+         scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in relation_types])
+         scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in relation_types])
+
+         return scores
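As a rough illustration of the two matching modes (the triplets below are invented for the example, and the "./custom_metric" path assumes the repository root as working directory): a prediction with the correct head and tail spans but a wrong tail_type is a miss under "strict" and a hit under "boundaries".

import evaluate

pred = [{"head": ["SNTAIWAN"], "head_type": ["brand"], "type": ["sell"],
         "tail": ["大馬士革玫瑰有機光燦系列"], "tail_type": ["brand"]}]   # wrong tail_type
ref = [{"head": ["SNTAIWAN"], "head_type": ["brand"], "type": ["sell"],
        "tail": ["大馬士革玫瑰有機光燦系列"], "tail_type": ["product"]}]

metric = evaluate.load("./custom_metric")
strict = metric.compute(predictions=pred, references=ref, mode="strict")
loose = metric.compute(predictions=pred, references=ref, mode="boundaries")
print(strict["ALL"]["f1"], loose["ALL"]["f1"])  # 0 under strict, 100.0 under boundaries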
custom_metric/metric.yml ADDED
@@ -0,0 +1,10 @@
+ metric_name: custom_relation_extraction
+ description: Custom Relation Extraction Metric
+ inputs:
+   - name: predictions
+     type: list
+     required: true
+   - name: references
+     type: list
+     required: true
+ compute_function: custom_metric.relation_extraction.compute