Spaces:
Running
Running
import copy | |
from typing import List, Tuple, Any, Optional | |
from pydantic import BaseModel, field_validator, computed_field | |
from surya.postprocessing.util import rescale_bbox | |
class PolygonBox(BaseModel):
    """Four-corner polygon with an optional confidence score.

    ``polygon`` holds four ``[x, y]`` corners. ``bbox`` reads them in the
    order ``merge`` writes them: ``[x1, y1], [x2, y1], [x2, y2], [x1, y2]``.
    The derived values ``bbox``, ``width``, ``height`` and ``area`` are
    properties, because the methods below read them as attributes.
    """

    polygon: List[List[float]]
    confidence: Optional[float] = None

    @field_validator('polygon')
    @classmethod
    def check_elements(cls, v: List[List[float]]) -> List[List[float]]:
        """Validate that ``polygon`` has 4 corners of 2 coordinates each.

        Raises:
            ValueError: if the corner count is not 4, or any corner does not
                have exactly 2 values.
        """
        if len(v) != 4:
            raise ValueError('polygon must have 4 corners')

        for corner in v:
            if len(corner) != 2:
                raise ValueError('corner must have 2 elements')
        return v

    @property
    def height(self) -> float:
        """Vertical extent of ``bbox`` (``y2 - y1``)."""
        return self.bbox[3] - self.bbox[1]

    @property
    def width(self) -> float:
        """Horizontal extent of ``bbox`` (``x2 - x1``)."""
        return self.bbox[2] - self.bbox[0]

    @property
    def area(self) -> float:
        """Area of ``bbox`` (``width * height``)."""
        return self.width * self.height

    @computed_field
    @property
    def bbox(self) -> List[float]:
        """Return ``[x1, y1, x2, y2]`` derived from the polygon corners.

        Only corner 0 (x and y), corner 1 (x) and corner 2 (y) are read, so
        the result depends on the corner order described on the class. Each
        axis is swapped if needed so that ``x1 <= x2`` and ``y1 <= y2``.
        """
        box = [self.polygon[0][0], self.polygon[0][1], self.polygon[1][0], self.polygon[2][1]]
        if box[0] > box[2]:
            box[0], box[2] = box[2], box[0]
        if box[1] > box[3]:
            box[1], box[3] = box[3], box[1]
        return box

    def rescale(self, processor_size, image_size):
        """Scale the polygon in place from ``processor_size`` to ``image_size``.

        Args:
            processor_size: ``(width, height)`` of the space the corners are
                currently expressed in.
            image_size: ``(width, height)`` of the target space.

        Each scaled coordinate is truncated with ``int()``.
        """
        # Point is in x, y format
        page_width, page_height = processor_size
        img_width, img_height = image_size

        width_scaler = img_width / page_width
        height_scaler = img_height / page_height

        # Work on a copy so the field is replaced in one assignment.
        new_corners = copy.deepcopy(self.polygon)
        for corner in new_corners:
            corner[0] = int(corner[0] * width_scaler)
            corner[1] = int(corner[1] * height_scaler)
        self.polygon = new_corners

    def fit_to_bounds(self, bounds):
        """Clamp every corner in place to ``bounds``.

        Args:
            bounds: indexable as ``[x_min, y_min, x_max, y_max]``.
        """
        new_corners = copy.deepcopy(self.polygon)
        for corner in new_corners:
            corner[0] = max(min(corner[0], bounds[2]), bounds[0])
            corner[1] = max(min(corner[1], bounds[3]), bounds[1])
        self.polygon = new_corners

    def merge(self, other):
        """Replace this polygon with the rectangle enclosing both ``bbox`` values.

        Args:
            other: object exposing a ``bbox`` of ``[x1, y1, x2, y2]``.
        """
        # Read each bbox once; the property rebuilds its list on every access.
        own = self.bbox
        theirs = other.bbox
        x1 = min(own[0], theirs[0])
        y1 = min(own[1], theirs[1])
        x2 = max(own[2], theirs[2])
        y2 = max(own[3], theirs[3])
        self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

    def intersection_area(self, other, x_margin=0, y_margin=0):
        """Return the overlap area of the two ``bbox`` values.

        Both boxes are grown by ``x_margin`` / ``y_margin`` on each side
        before the overlap is measured. Returns 0 when they do not overlap.
        """
        own = self.bbox
        theirs = other.bbox
        x_overlap = max(0, min(own[2] + x_margin, theirs[2] + x_margin) - max(own[0] - x_margin, theirs[0] - x_margin))
        y_overlap = max(0, min(own[3] + y_margin, theirs[3] + y_margin) - max(own[1] - y_margin, theirs[1] - y_margin))
        return x_overlap * y_overlap

    def intersection_pct(self, other, x_margin=0, y_margin=0):
        """Return the overlap with ``other`` as a fraction of this box's area.

        Args:
            other: object exposing ``bbox``, ``width`` and ``height``.
            x_margin: fraction in ``[0, 1]`` of the smaller of the two widths,
                converted to an absolute margin with ``int()``.
            y_margin: same, relative to the smaller of the two heights.

        Returns:
            ``intersection_area / self.area``, or 0 when ``self.area`` is 0.
            The margins enlarge the measured overlap but not the divisor, so
            the result can exceed 1 when margins are non-zero.
        """
        assert 0 <= x_margin <= 1
        assert 0 <= y_margin <= 1

        # Guard the division below.
        if self.area == 0:
            return 0

        if x_margin:
            x_margin = int(min(self.width, other.width) * x_margin)
        if y_margin:
            y_margin = int(min(self.height, other.height) * y_margin)

        intersection = self.intersection_area(other, x_margin, y_margin)
        return intersection / self.area
class Bbox(BaseModel):
    """Axis-aligned box stored as ``[x1, y1, x2, y2]``.

    ``width``, ``height`` and ``area`` are properties, because ``area`` and
    ``intersection_pct`` read them as attributes.
    NOTE(review): ``polygon`` and ``center`` are exposed as properties too,
    mirroring ``PolygonBox.polygon`` being an attribute; confirm against
    callers.
    """

    bbox: List[float]

    @field_validator('bbox')
    @classmethod
    def check_4_elements(cls, v: List[float]) -> List[float]:
        """Validate that ``bbox`` has exactly 4 values.

        Raises:
            ValueError: if the length is not 4.
        """
        if len(v) != 4:
            raise ValueError('bbox must have 4 elements')
        return v

    def rescale_bbox(self, orig_size, new_size):
        """Replace ``bbox`` with the result of the imported ``rescale_bbox`` helper."""
        # Inside a method the bare name resolves to the module-level import
        # from surya.postprocessing.util, not to this method.
        self.bbox = rescale_bbox(self.bbox, orig_size, new_size)

    def round_bbox(self, divisor):
        """Floor every coordinate in place to a multiple of ``divisor``."""
        self.bbox = [x // divisor * divisor for x in self.bbox]

    @property
    def height(self) -> float:
        """Vertical extent (``y2 - y1``)."""
        return self.bbox[3] - self.bbox[1]

    @property
    def width(self) -> float:
        """Horizontal extent (``x2 - x1``)."""
        return self.bbox[2] - self.bbox[0]

    @property
    def area(self) -> float:
        """Box area (``width * height``)."""
        return self.width * self.height

    @property
    def polygon(self) -> List[List[float]]:
        """The four corners as ``[x1, y1], [x2, y1], [x2, y2], [x1, y2]``."""
        x1, y1, x2, y2 = self.bbox
        return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

    @property
    def center(self) -> List[float]:
        """Midpoint of the box as ``[x, y]``."""
        x1, y1, x2, y2 = self.bbox
        return [(x1 + x2) / 2, (y1 + y2) / 2]

    def intersection_pct(self, other):
        """Return the overlap with ``other`` as a fraction of this box's area.

        Args:
            other: object exposing a ``bbox`` of ``[x1, y1, x2, y2]``.

        Returns:
            ``overlap_area / self.area``, or 0 when ``self.area`` is 0.
        """
        # Guard the division below.
        if self.area == 0:
            return 0

        x_overlap = max(0, min(self.bbox[2], other.bbox[2]) - max(self.bbox[0], other.bbox[0]))
        y_overlap = max(0, min(self.bbox[3], other.bbox[3]) - max(self.bbox[1], other.bbox[1]))
        intersection = x_overlap * y_overlap
        return intersection / self.area
class LayoutBox(PolygonBox):
    """A ``PolygonBox`` that additionally carries a string ``label``."""

    label: str
class OrderBox(Bbox):
    """A ``Bbox`` that additionally carries an integer ``position``."""

    # presumably the box's index in an ordering; confirm against the producer
    position: int
class ColumnLine(Bbox):
    """A ``Bbox`` tagged with two boolean orientation flags."""

    vertical: bool
    horizontal: bool
class TextLine(PolygonBox):
    """A ``PolygonBox`` that additionally carries a ``text`` string."""

    text: str
    # Redeclares the field already inherited from PolygonBox with the same
    # type and default.
    confidence: Optional[float] = None
class OCRResult(BaseModel):
    """Container for a list of ``TextLine`` items plus image metadata."""

    text_lines: List[TextLine]
    # Optional; defaults to None when not supplied.
    languages: Optional[List[str]] = None
    # presumably [x1, y1, x2, y2] of the source image; confirm with producer
    image_bbox: List[float]
class TextDetectionResult(BaseModel):
    """Container for detected ``PolygonBox`` items and related outputs."""

    bboxes: List[PolygonBox]
    vertical_lines: List[ColumnLine]
    # Typed as Any: no validation is applied to these two fields.
    heatmap: Any
    affinity_map: Any
    # presumably [x1, y1, x2, y2] of the source image; confirm with producer
    image_bbox: List[float]
class LayoutResult(BaseModel):
    """Container for a list of ``LayoutBox`` items and a segmentation map."""

    bboxes: List[LayoutBox]
    # Typed as Any: no validation is applied to this field.
    segmentation_map: Any
    # presumably [x1, y1, x2, y2] of the source image; confirm with producer
    image_bbox: List[float]
class OrderResult(BaseModel):
    """Container for a list of ``OrderBox`` items plus image metadata."""

    bboxes: List[OrderBox]
    # presumably [x1, y1, x2, y2] of the source image; confirm with producer
    image_bbox: List[float]
class TableCell(Bbox):
    """A ``Bbox`` with optional row id, column id and text.

    All three extra fields default to ``None``.
    """

    row_id: Optional[int] = None
    col_id: Optional[int] = None
    text: Optional[str] = None
class TableResult(BaseModel):
    """Container for table cells, rows and columns plus image metadata.

    Rows and columns use the same ``TableCell`` type as individual cells.
    """

    cells: List[TableCell]
    rows: List[TableCell]
    cols: List[TableCell]
    # presumably [x1, y1, x2, y2] of the source image; confirm with producer
    image_bbox: List[float]