GGroenendaal commited on
Commit
081d5bf
1 Parent(s): 8fe5a80

move preprocessing to dependency injection

Browse files
Files changed (2) hide show
  1. base_model/evaluate.py +18 -20
  2. base_model/string_utils.py +20 -0
base_model/evaluate.py CHANGED
@@ -1,29 +1,27 @@
1
- def normalize_text(s: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
2
  """Preprocesses the sentence string by normalizing.
3
 
4
  Args:
5
  s (str): the sentence
6
 
7
  Returns:
8
- string: normalized sentence
9
  """
10
- import string, re
11
-
12
- def remove_articles(text):
13
- regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
14
- return re.sub(regex, " ", text)
15
-
16
- def white_space_fix(text):
17
- return " ".join(text.split())
18
-
19
- def remove_punc(text):
20
- exclude = set(string.punctuation)
21
- return "".join(ch for ch in text if ch not in exclude)
22
 
23
- def lower(text):
24
- return text.lower()
25
 
26
- return white_space_fix(remove_articles(remove_punc(lower(s))))
27
 
28
 
29
  def compute_exact_match(prediction: str, answer: str) -> int:
@@ -36,7 +34,7 @@ def compute_exact_match(prediction: str, answer: str) -> int:
36
  Returns:
37
  int: 1 for exact match, 0 for not
38
  """
39
- return int(normalize_text(prediction) == normalize_text(answer))
40
 
41
 
42
  def compute_f1(prediction: str, answer: str) -> float:
@@ -49,8 +47,8 @@ def compute_f1(prediction: str, answer: str) -> float:
49
  Returns:
50
  boolean: the f1 score
51
  """
52
- pred_tokens = normalize_text(prediction).split()
53
- answer_tokens = normalize_text(answer).split()
54
 
55
  if len(pred_tokens) == 0 or len(answer_tokens) == 0:
56
  return int(pred_tokens == answer_tokens)
 
1
+ from typing import Callable, List
2
+
3
+ from base_model.string_utils import lower, remove_articles, remove_punc, white_space_fix
4
+
5
+
6
+ def normalize_text(inp: str, functions: List[Callable[[str], str]]):
7
+ for fun in functions:
8
+ inp = fun(inp)
9
+ return inp
10
+
11
+
12
+ def normalize_text_default(inp: str) -> str:
13
  """Preprocesses the sentence string by normalizing.
14
 
15
  Args:
16
  s (str): the sentence
17
 
18
  Returns:
19
+ string: normalized with default parames
20
  """
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ steps = [remove_articles, white_space_fix, remove_punc, lower]
 
23
 
24
+ return normalize_text(inp, steps)
25
 
26
 
27
  def compute_exact_match(prediction: str, answer: str) -> int:
 
34
  Returns:
35
  int: 1 for exact match, 0 for not
36
  """
37
+ return int(normalize_text_default(prediction) == normalize_text_default(answer))
38
 
39
 
40
  def compute_f1(prediction: str, answer: str) -> float:
 
47
  Returns:
48
  boolean: the f1 score
49
  """
50
+ pred_tokens = normalize_text_default(prediction).split()
51
+ answer_tokens = normalize_text_default(answer).split()
52
 
53
  if len(pred_tokens) == 0 or len(answer_tokens) == 0:
54
  return int(pred_tokens == answer_tokens)
base_model/string_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import string
3
+
4
+
5
+ def remove_articles(text):
6
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
7
+ return re.sub(regex, " ", text)
8
+
9
+
10
+ def white_space_fix(text):
11
+ return " ".join(text.split())
12
+
13
+
14
+ def remove_punc(text):
15
+ exclude = set(string.punctuation)
16
+ return "".join(ch for ch in text if ch not in exclude)
17
+
18
+
19
+ def lower(text):
20
+ return text.lower()