nam194 commited on
Commit
dc354d3
1 Parent(s): 44a5c94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -104,10 +104,12 @@ def pred_resume(pdf_path) -> dict:
104
  if text.replace(" ","") != "":
105
  bboxes.append(normalize_bbox([xmin, ymin, xmax, ymax], image.size))
106
  words.append(decontracted(text))
 
107
  fake_label = ["O"] * len(words)
108
  encoding = processor(image, words, boxes=bboxes, word_labels=fake_label, truncation=True, stride=256,
109
  padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True)
110
  labels = encoding["labels"]
 
111
  offset_mapping = encoding.pop('offset_mapping')
112
  overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
113
  encoding = {k: torch.tensor(v) for k,v in encoding.items() if k != "labels"}
@@ -128,12 +130,16 @@ def pred_resume(pdf_path) -> dict:
128
  if i>0:
129
  labels[i] = labels[i][256:]
130
  predictions[i] = predictions[i][256:]
 
131
  predictions = [j for i in predictions for j in i]
 
132
  labels = [j for i in labels for j in i]
133
  true_predictions = [id2label[pred] for pred, label in zip(predictions, labels) if label != -100]
134
- for i, pred in enumerate(true_predictions):
 
135
  if pred in key_list:
136
- result[pred].append(words[i])
 
137
  return str(result)
138
  def norm(result: str) -> str:
139
  result = ast.literal_eval(result)
 
104
  if text.replace(" ","") != "":
105
  bboxes.append(normalize_bbox([xmin, ymin, xmax, ymax], image.size))
106
  words.append(decontracted(text))
107
+ text_reverse = {str(bboxes[i]): words[i] for i,_ in enumerate(words)}
108
  fake_label = ["O"] * len(words)
109
  encoding = processor(image, words, boxes=bboxes, word_labels=fake_label, truncation=True, stride=256,
110
  padding="max_length", max_length=512, return_overflowing_tokens=True, return_offsets_mapping=True)
111
  labels = encoding["labels"]
112
+ key_box = encoding["bbox"]
113
  offset_mapping = encoding.pop('offset_mapping')
114
  overflow_to_sample_mapping = encoding.pop('overflow_to_sample_mapping')
115
  encoding = {k: torch.tensor(v) for k,v in encoding.items() if k != "labels"}
 
130
  if i>0:
131
  labels[i] = labels[i][256:]
132
  predictions[i] = predictions[i][256:]
133
+ key_box[i] = key_box[i][256:]
134
  predictions = [j for i in predictions for j in i]
135
+ key_box = [j for i in key_box for j in i]
136
  labels = [j for i in labels for j in i]
137
  true_predictions = [id2label[pred] for pred, label in zip(predictions, labels) if label != -100]
138
+ key_box = [box for box, label in zip(key_box, labels) if label != -100]
139
+ for box, pred in zip(key_box, true_predictions):
140
  if pred in key_list:
141
+ result[pred].append(text_reverse[str(box)])
142
+ result = {k: list(set(v)) for k, v in result.items()}
143
  return str(result)
144
  def norm(result: str) -> str:
145
  result = ast.literal_eval(result)