import copy
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import Dict, List, Optional, Set, Tuple

from PIL import Image
import numpy as np

from surya.detection import batch_detection
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
from surya.settings import settings


def _mask_logits(heatmaps: List[np.ndarray], vertical_line_bboxes, segment_assignment, vertical_line_width) -> np.ndarray:
    """Stack the heatmaps and zero out column lines, blank pixels and pixels owned by other classes."""
    logits = np.stack(heatmaps, axis=0)
    # Heatmaps are indexed [y, x], so the horizontal extent is shape[1].
    heatmap_width = heatmaps[0].shape[1]

    for line in vertical_line_bboxes:
        # Give some width to the vertical lines, clamped to the heatmap width
        vert_bbox = list(line.bbox)
        vert_bbox[2] = min(heatmap_width, vert_bbox[2] + vertical_line_width)
        logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0  # zero out where the column lines are

    logits[:, logits[0] >= .5] = 0  # zero out where blanks are

    # Zero out where other segments are
    for i in range(logits.shape[0]):
        logits[i, segment_assignment != i] = 0
    return logits


def _detect_layout_boxes(logits: np.ndarray, id2label) -> List[LayoutBox]:
    """Extract one LayoutBox per detected region from each non-blank class heatmap."""
    detected_boxes = []
    for heatmap_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = logits[heatmap_idx]
        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
            continue

        bounds = [0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1]
        for bbox in get_detected_boxes(heatmap):
            if bbox.area <= 25:  # drop tiny regions
                continue
            bbox.fit_to_bounds(bounds)
            detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1))

    return sorted(detected_boxes, key=lambda x: x.confidence, reverse=True)


def _assign_lines_to_boxes(detected_boxes: List[LayoutBox], line_bboxes) -> Tuple[Dict[int, list], Set[int]]:
    """Greedily give each text line to the first layout box it overlaps enough.

    Returns a mapping of box index -> list of assigned line ``bbox`` values (only boxes
    with at least one line appear), and the set of line indices that were assigned.
    """
    box_lines = defaultdict(list)
    used_lines = set()
    # We try 2 rounds of identifying the correct lines to snap to
    # First round is majority intersection, second lowers the threshold
    for thresh in [.5, .4]:
        for bbox_idx, bbox in enumerate(detected_boxes):
            for line_idx, line_bbox in enumerate(line_bboxes):
                if line_idx not in used_lines and line_bbox.intersection_pct(bbox) > thresh:
                    box_lines[bbox_idx].append(line_bbox.bbox)
                    used_lines.add(line_idx)
    # Plain dict so later lookups cannot silently insert empty entries.
    return dict(box_lines), used_lines


def _snap_boxes_to_lines(detected_boxes: List[LayoutBox], box_lines: Dict[int, list]) -> List[LayoutBox]:
    """Filter layout boxes and resize them to the text lines they cover."""
    new_boxes = []
    for bbox_idx, bbox in enumerate(detected_boxes):
        if bbox.label == "Picture" and bbox.area < 200:  # Remove very small figures
            continue

        covered_lines = box_lines.get(bbox_idx, [])
        # Skip if we didn't find any lines to snap to, except for Pictures and Formulas
        if not covered_lines and bbox.label not in ["Picture", "Formula"]:
            continue

        # Snap non-picture layout boxes to correct text boundaries
        if covered_lines and bbox.label != "Picture":
            min_x = min(line[0] for line in covered_lines)
            min_y = min(line[1] for line in covered_lines)
            max_x = max(line[2] for line in covered_lines)
            max_y = max(line[3] for line in covered_lines)

            # Tables and formulas can contain text, but text isn't the whole area
            if bbox.label in ["Table", "Formula"]:
                min_x = min(min_x, min(point[0] for point in bbox.polygon))
                min_y = min(min_y, min(point[1] for point in bbox.polygon))
                max_x = max(max_x, max(point[0] for point in bbox.polygon))
                max_y = max(max_y, max(point[1] for point in bbox.polygon))

            bbox.polygon = [
                [min_x, min_y],
                [max_x, min_y],
                [max_x, max_y],
                [min_x, max_y]
            ]

        # Only pictures that actually contain text lines become Figures
        if covered_lines and bbox.label == "Picture":
            bbox.label = "Figure"

        new_boxes.append(bbox)
    return new_boxes


def _merge_overlapping_boxes(new_boxes: List[LayoutBox]) -> List[LayoutBox]:
    """Merge overlapping boxes of the same mergeable type into the earlier box."""
    # Merge tables together (sometimes one column is detected as a separate table)
    mergeable_types = ["Table", "Picture", "Figure"]
    for ftype in mergeable_types:
        to_remove = set()
        for bbox_idx, bbox in enumerate(new_boxes):
            if bbox.label != ftype or bbox_idx in to_remove:
                continue

            for bbox_idx2, bbox2 in enumerate(new_boxes):
                if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
                    continue

                if bbox.intersection_pct(bbox2, x_margin=.25) > .1:
                    bbox.merge(bbox2)
                    to_remove.add(bbox_idx2)

        new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove]
    return new_boxes


def _remove_contained_boxes(boxes: List[LayoutBox]) -> List[LayoutBox]:
    """Drop boxes lying at least 95% inside another box, except Captions.

    When two boxes each contain the other (near-duplicates), the earlier one is kept
    rather than both being removed.
    """
    def is_contained(inner, outer):
        return inner.label not in ["Caption"] and inner.intersection_pct(outer) >= .95

    removed = set()
    for i, outer in enumerate(boxes):
        for j, inner in enumerate(boxes):
            if i == j or not is_contained(inner, outer):
                continue
            # Mutual containment: the earlier box j already removed box i, so i must
            # not remove j in return (that would drop both copies).
            if j < i and is_contained(outer, inner):
                continue
            removed.add(j)

    return [box for idx, box in enumerate(boxes) if idx not in removed]


def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
    """Build layout boxes from class heatmaps, snapped to detected text lines.

    Steps:
    - Mask out column separators, blank pixels and pixels owned by other classes.
    - Detect regions per class and assign text lines to them.
    - Snap the boxes to those lines and merge split tables/pictures.
    - Add unassigned lines as "Text" boxes.
    - Rescale everything to original-image coordinates and drop contained boxes.

    Args:
        detection_result: text detection output for the same page. Its ``bboxes``
            and ``vertical_lines`` are read but not modified.
        heatmaps: one 2D array per class; index 0 is the blank class.
        orig_size: original image size; presumably ``(width, height)``, mirroring
            ``reversed(heatmap.shape)`` — TODO confirm.
        id2label: mapping from heatmap index to label name.
        segment_assignment: per-pixel winning class index, same shape as a heatmap.
        vertical_line_width: extra width (in heatmap pixels) added to each vertical
            line before masking it out.

    Returns:
        Layout boxes rescaled to ``orig_size``.
    """
    heatmap_size = list(reversed(heatmaps[0].shape))

    # The rescale calls below mutate boxes in place; work on copies so the caller's
    # detection result stays in original-image coordinates.
    vertical_line_bboxes = copy.deepcopy(detection_result.vertical_lines)
    line_bboxes = copy.deepcopy(detection_result.bboxes)

    # Scale back to processor size
    for line in vertical_line_bboxes:
        line.rescale_bbox(orig_size, heatmap_size)
    for line in line_bboxes:
        line.rescale(orig_size, heatmap_size)

    logits = _mask_logits(heatmaps, vertical_line_bboxes, segment_assignment, vertical_line_width)
    detected_boxes = _detect_layout_boxes(logits, id2label)

    # Expand bbox to cover intersecting lines
    box_lines, used_lines = _assign_lines_to_boxes(detected_boxes, line_bboxes)
    new_boxes = _snap_boxes_to_lines(detected_boxes, box_lines)
    new_boxes = _merge_overlapping_boxes(new_boxes)

    # Ensure we account for all text lines in the layout
    for line_idx, line in enumerate(line_bboxes):
        if line_idx not in used_lines:
            new_boxes.append(LayoutBox(polygon=line.polygon, label="Text", confidence=.5))

    for bbox in new_boxes:
        bbox.rescale(heatmap_size, orig_size)

    detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]

    # Remove bboxes contained inside others, unless they're captions
    return _remove_contained_boxes(detected_boxes)


def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
    """Build layout boxes directly from class heatmaps (no text-line snapping).

    Each non-blank class heatmap is masked to the pixels that class wins in
    ``segment_assignment``. The masking is done on a copy, so ``heatmaps`` is left
    untouched. Boxes are extracted and then filtered with ``keep_largest_boxes``.
    """
    bboxes = []
    for i in range(1, len(id2label)):  # Skip the blank class
        heatmap = np.array(heatmaps[i], copy=True)
        assert heatmap.shape == segment_assignment.shape
        heatmap[segment_assignment != i] = 0  # zero out where another segment is

        # Skip processing empty labels
        if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD:
            continue

        for bb in get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size):
            bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))

    return keep_largest_boxes(bboxes)


def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
    """Post-process one page's heatmaps into a LayoutResult.

    Args:
        heatmaps: one 2D array per class for this page.
        orig_size: original image size, used for rescaling and ``image_bbox``.
        id2label: mapping from heatmap index to label name.
        detection_results: optional text detection result for this single page. When
            given, boxes are snapped to its text lines.
    """
    logits = np.stack(heatmaps, axis=0)
    segment_assignment = logits.argmax(axis=0)

    if detection_results is not None:
        bboxes = get_regions_from_detection_result(detection_results, heatmaps, orig_size, id2label, segment_assignment)
    else:
        bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)

    segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8))
    return LayoutResult(
        bboxes=bboxes,
        segmentation_map=segmentation_img,
        heatmaps=heatmaps,
        image_bbox=[0, 0, orig_size[0], orig_size[1]]
    )


def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
    """Run layout detection over ``images`` and post-process every page.

    Post-processing runs in a ProcessPoolExecutor unless running inside Streamlit or
    there are fewer than ``settings.DETECTOR_MIN_PARALLEL_THRESH`` images.

    Args:
        images: input page images.
        model, processor: layout model and its processor, forwarded to ``batch_detection``.
        detection_results: optional per-image text detection results, indexed in the
            same order as ``images``.
        batch_size: forwarded to ``batch_detection``.

    Returns:
        One LayoutResult per image, in input order.
    """
    layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label
    results = []

    def detection_result_for(img_idx):
        return detection_results[img_idx] if detection_results else None

    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    img_idx = 0
    if parallelize:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for preds, orig_sizes in layout_generator:
                futures = []
                for pred, orig_size in zip(preds, orig_sizes):
                    futures.append(executor.submit(
                        parallel_get_regions,
                        pred,
                        orig_size,
                        id2label,
                        detection_result_for(img_idx)
                    ))
                    img_idx += 1

                # Results of a batch are collected before the next batch is generated
                results.extend(future.result() for future in futures)
    else:
        for preds, orig_sizes in layout_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_regions(
                    pred,
                    orig_size,
                    id2label,
                    detection_result_for(img_idx)
                ))
                img_idx += 1

    return results