SuryaOCR / surya /schema.py
Jiangxz01's picture
Upload 56 files
52f1bcb verified
raw
history blame
5.47 kB
import copy
from typing import List, Tuple, Any, Optional
from pydantic import BaseModel, field_validator, computed_field
from surya.postprocessing.util import rescale_bbox
class PolygonBox(BaseModel):
    """A four-cornered region stored as a list of ``[x, y]`` points.

    Attributes are validated on construction only; the mutating helpers
    (``rescale``, ``fit_to_bounds``, ``merge``) assign a new ``polygon``
    list directly.
    """

    polygon: List[List[float]]
    confidence: Optional[float] = None

    @field_validator('polygon')
    @classmethod
    def check_elements(cls, v: List[List[float]]) -> List[List[float]]:
        """Ensure the polygon has exactly 4 corners of 2 coordinates each.

        Raises:
            ValueError: if the corner count or any corner's length is wrong.
        """
        if len(v) != 4:
            # This check concerns the number of corners, not a corner's size.
            raise ValueError('polygon must have 4 corners')
        for corner in v:
            if len(corner) != 2:
                raise ValueError('corner must have 2 elements')
        return v

    @property
    def height(self):
        """Vertical extent of ``bbox`` (y2 - y1)."""
        box = self.bbox  # bbox is rebuilt on every access; compute it once
        return box[3] - box[1]

    @property
    def width(self):
        """Horizontal extent of ``bbox`` (x2 - x1)."""
        box = self.bbox
        return box[2] - box[0]

    @property
    def area(self):
        """Area of ``bbox`` (width * height)."""
        return self.width * self.height

    @computed_field
    @property
    def bbox(self) -> List[float]:
        """Return ``[x1, y1, x2, y2]`` derived from the polygon corners.

        Uses corner 0 for (x1, y1), the x of corner 1 for x2 and the y of
        corner 2 for y2, then swaps each axis pair if needed so the smaller
        value comes first. The remaining coordinates are not consulted, so
        for a polygon that is not axis-aligned this is not necessarily the
        tight enclosing box.
        """
        box = [self.polygon[0][0], self.polygon[0][1], self.polygon[1][0], self.polygon[2][1]]
        if box[0] > box[2]:
            box[0], box[2] = box[2], box[0]
        if box[1] > box[3]:
            box[1], box[3] = box[3], box[1]
        return box

    def rescale(self, processor_size, image_size):
        """Scale the polygon in place from ``processor_size`` to ``image_size``.

        Args:
            processor_size: ``(width, height)`` of the space the corners are
                currently expressed in.
            image_size: ``(width, height)`` of the target space.

        Each coordinate is multiplied by the per-axis ratio and truncated
        with ``int()``.
        """
        # Point is in x, y format
        page_width, page_height = processor_size
        img_width, img_height = image_size
        width_scaler = img_width / page_width
        height_scaler = img_height / page_height
        self.polygon = [
            [int(corner[0] * width_scaler), int(corner[1] * height_scaler)]
            for corner in self.polygon
        ]

    def fit_to_bounds(self, bounds):
        """Clamp every corner in place to ``bounds`` = ``[x1, y1, x2, y2]``."""
        self.polygon = [
            [
                max(min(corner[0], bounds[2]), bounds[0]),
                max(min(corner[1], bounds[3]), bounds[1]),
            ]
            for corner in self.polygon
        ]

    def merge(self, other):
        """Replace this polygon with the axis-aligned union of both bboxes.

        Args:
            other: an object exposing a ``bbox`` of ``[x1, y1, x2, y2]``.
        """
        mine, theirs = self.bbox, other.bbox
        x1 = min(mine[0], theirs[0])
        y1 = min(mine[1], theirs[1])
        x2 = max(mine[2], theirs[2])
        y2 = max(mine[3], theirs[3])
        self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]

    def intersection_area(self, other, x_margin=0, y_margin=0):
        """Return the overlap area of the two bboxes.

        Both boxes are grown by ``x_margin`` / ``y_margin`` (absolute units)
        on each side before the overlap is measured. Returns 0 when the
        boxes do not overlap on either axis.
        """
        mine, theirs = self.bbox, other.bbox
        x_overlap = max(0, min(mine[2] + x_margin, theirs[2] + x_margin) - max(mine[0] - x_margin, theirs[0] - x_margin))
        y_overlap = max(0, min(mine[3] + y_margin, theirs[3] + y_margin) - max(mine[1] - y_margin, theirs[1] - y_margin))
        return x_overlap * y_overlap

    def intersection_pct(self, other, x_margin=0, y_margin=0):
        """Return the overlap with ``other`` as a fraction of this box's area.

        Args:
            other: box to intersect with.
            x_margin: fraction in ``[0, 1]`` of the smaller of the two widths
                to use as horizontal margin.
            y_margin: fraction in ``[0, 1]`` of the smaller of the two heights
                to use as vertical margin.

        Returns:
            0 if this box has zero area; otherwise intersection / own area.
            The margins enlarge the intersection but not the denominator, so
            the result can exceed 1 when margins are used.

        Raises:
            AssertionError: if a margin is outside ``[0, 1]``.
        """
        assert 0 <= x_margin <= 1
        assert 0 <= y_margin <= 1
        own_area = self.area
        if own_area == 0:
            return 0

        # Convert the fractional margins into absolute (truncated) units.
        if x_margin:
            x_margin = int(min(self.width, other.width) * x_margin)
        if y_margin:
            y_margin = int(min(self.height, other.height) * y_margin)

        intersection = self.intersection_area(other, x_margin, y_margin)
        return intersection / own_area
class Bbox(BaseModel):
    """An axis-aligned box stored as a four-element list.

    The properties below index it as ``[x1, y1, x2, y2]``.
    """

    bbox: List[float]

    @field_validator('bbox')
    @classmethod
    def check_4_elements(cls, v: List[float]) -> List[float]:
        """Reject any bbox whose length is not exactly 4."""
        if len(v) == 4:
            return v
        raise ValueError('bbox must have 4 elements')

    def rescale_bbox(self, orig_size, new_size):
        """Replace ``bbox`` with the result of the module-level ``rescale_bbox`` helper."""
        # Inside a method the bare name resolves to the imported helper,
        # not to this method.
        rescaled = rescale_bbox(self.bbox, orig_size, new_size)
        self.bbox = rescaled

    def round_bbox(self, divisor):
        """Floor every coordinate in place to a multiple of ``divisor``."""
        snapped = []
        for coord in self.bbox:
            snapped.append(coord // divisor * divisor)
        self.bbox = snapped

    @property
    def height(self):
        """Vertical extent: ``bbox[3] - bbox[1]``."""
        top = self.bbox[1]
        bottom = self.bbox[3]
        return bottom - top

    @property
    def width(self):
        """Horizontal extent: ``bbox[2] - bbox[0]``."""
        left = self.bbox[0]
        right = self.bbox[2]
        return right - left

    @property
    def area(self):
        """Width multiplied by height."""
        w = self.width
        h = self.height
        return w * h

    @property
    def polygon(self):
        """The box as four ``[x, y]`` corners, starting at ``(bbox[0], bbox[1])``."""
        box = self.bbox
        top_left = [box[0], box[1]]
        top_right = [box[2], box[1]]
        bottom_right = [box[2], box[3]]
        bottom_left = [box[0], box[3]]
        return [top_left, top_right, bottom_right, bottom_left]

    @property
    def center(self):
        """Midpoint of the box as ``[x, y]``."""
        box = self.bbox
        mid_x = (box[0] + box[2]) / 2
        mid_y = (box[1] + box[3]) / 2
        return [mid_x, mid_y]

    def intersection_pct(self, other):
        """Return the overlap with ``other`` as a fraction of this box's area.

        Returns 0 when this box has zero area.
        """
        own_area = self.area
        if own_area == 0:
            return 0

        mine = self.bbox
        theirs = other.bbox

        overlap_left = max(mine[0], theirs[0])
        overlap_right = min(mine[2], theirs[2])
        x_overlap = max(0, overlap_right - overlap_left)

        overlap_top = max(mine[1], theirs[1])
        overlap_bottom = min(mine[3], theirs[3])
        y_overlap = max(0, overlap_bottom - overlap_top)

        return (x_overlap * y_overlap) / own_area
class LayoutBox(PolygonBox):
    """A ``PolygonBox`` that additionally carries a string label."""

    # Presumably the layout category of the region — confirm against the
    # layout model's label set.
    label: str
class OrderBox(Bbox):
    """A ``Bbox`` that additionally carries an integer position."""

    # Presumably the reading-order index of the box — confirm with the caller.
    position: int
class ColumnLine(Bbox):
    """A ``Bbox`` tagged with two boolean orientation flags."""

    # The two flags are independent required fields; nothing here enforces
    # that exactly one of them is set.
    vertical: bool
    horizontal: bool
class TextLine(PolygonBox):
    """A ``PolygonBox`` together with a text string."""

    text: str
    # Redeclared with the same type and default as on PolygonBox.
    confidence: Optional[float] = None
class OCRResult(BaseModel):
    """Container for a list of text lines plus image-level metadata."""

    text_lines: List[TextLine]
    # Optional; defaults to None when not supplied.
    languages: Optional[List[str]] = None
    # Presumably [x1, y1, x2, y2] of the source image — confirm with producer.
    image_bbox: List[float]
class TextDetectionResult(BaseModel):
    """Container for detected polygons, column lines and raw model maps."""

    bboxes: List[PolygonBox]
    vertical_lines: List[ColumnLine]
    # Typed as Any, so pydantic does not constrain what is stored in these
    # two fields.
    heatmap: Any
    affinity_map: Any
    # Presumably [x1, y1, x2, y2] of the source image — confirm with producer.
    image_bbox: List[float]
class LayoutResult(BaseModel):
    """Container for labelled layout boxes and a segmentation map."""

    bboxes: List[LayoutBox]
    # Typed as Any, so pydantic does not constrain what is stored here.
    segmentation_map: Any
    # Presumably [x1, y1, x2, y2] of the source image — confirm with producer.
    image_bbox: List[float]
class OrderResult(BaseModel):
    """Container for a list of positioned boxes."""

    bboxes: List[OrderBox]
    # Presumably [x1, y1, x2, y2] of the source image — confirm with producer.
    image_bbox: List[float]
class TableCell(Bbox):
    """A ``Bbox`` with optional row/column ids and optional text."""

    # All three fields default to None when not supplied.
    row_id: Optional[int] = None
    col_id: Optional[int] = None
    text: Optional[str] = None
class TableResult(BaseModel):
    """Container for table cells, rows and columns.

    All three collections use the same ``TableCell`` element type.
    """

    cells: List[TableCell]
    rows: List[TableCell]
    cols: List[TableCell]
    # Presumably [x1, y1, x2, y2] of the source image — confirm with producer.
    image_bbox: List[float]