DarrenChensformer commited on
Commit
ca4743e
1 Parent(s): 8f7a372

Add default value of valid_labels

Browse files
Files changed (1) hide show
  1. action_generation.py +19 -1
action_generation.py CHANGED
@@ -56,8 +56,21 @@ Examples:
56
  # TODO: Define external resources urls if needed
57
  BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
58
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  class BaseEvaluater:
60
  eps = 1e-8
 
61
 
62
  def __call__(self, preds, labels):
63
  return self._compute(preds, labels)
@@ -78,10 +91,15 @@ class BaseEvaluater:
78
  "recall": round(recall, 4),
79
  "f1": round(f1, 4)
80
  }
 
 
 
 
81
 
82
  class ClassEvaluater(BaseEvaluater):
83
  def __init__(self, valid_labels=None):
84
  self.valid_labels = valid_labels
 
85
 
86
  def __call__(self, preds, labels):
87
  preds = map(self.extract_class, preds)
@@ -92,7 +110,6 @@ class ClassEvaluater(BaseEvaluater):
92
  return self._compute(preds, labels)
93
 
94
  def extract_valid(self, tags):
95
- # TODO: if valid_labels is None:
96
  tags = list(filter(lambda tag: tag in self.valid_labels, tags))
97
  return tags
98
 
@@ -123,6 +140,7 @@ class ClassEvaluater(BaseEvaluater):
123
  class PhraseEvaluater(BaseEvaluater):
124
  def __init__(self, valid_labels=None):
125
  self.valid_labels = valid_labels
 
126
 
127
  def __call__(self, preds, labels):
128
  preds = map(self.extract_phrase, preds)
 
56
  # TODO: Define external resources urls if needed
57
  BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
58
 
59
+ VALID_LABELS = [
60
+ "/開箱",
61
+ "/教學",
62
+ "/表達",
63
+ "/分享/外部資訊",
64
+ "/分享/個人資訊",
65
+ "/推薦/產品",
66
+ "/推薦/服務",
67
+ "/推薦/其他",
68
+ ""
69
+ ]
70
+
71
  class BaseEvaluater:
72
  eps = 1e-8
73
+ valid_labels = None
74
 
75
  def __call__(self, preds, labels):
76
  return self._compute(preds, labels)
 
91
  "recall": round(recall, 4),
92
  "f1": round(f1, 4)
93
  }
94
+
95
+ def _init_valid_labels(self):
96
+ if self.valid_labels is None:
97
+ self.valid_labels = VALID_LABELS
98
 
99
  class ClassEvaluater(BaseEvaluater):
100
  def __init__(self, valid_labels=None):
101
  self.valid_labels = valid_labels
102
+ self._init_valid_labels()
103
 
104
  def __call__(self, preds, labels):
105
  preds = map(self.extract_class, preds)
 
110
  return self._compute(preds, labels)
111
 
112
  def extract_valid(self, tags):
 
113
  tags = list(filter(lambda tag: tag in self.valid_labels, tags))
114
  return tags
115
 
 
140
  class PhraseEvaluater(BaseEvaluater):
141
  def __init__(self, valid_labels=None):
142
  self.valid_labels = valid_labels
143
+ self._init_valid_labels()
144
 
145
  def __call__(self, preds, labels):
146
  preds = map(self.extract_phrase, preds)