diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..50d28409cb378495803726ace4ed152313ab22a5 --- /dev/null +++ b/app.py @@ -0,0 +1,13 @@ +import argparse +import subprocess +import os + + +def run_app(): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + ocr_app_path = os.path.join(cur_dir, "ocr_app.py") + cmd = ["streamlit", "run", ocr_app_path] + subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"}) + +if __name__ == "__main__": + run_app() \ No newline at end of file diff --git a/detect_layout.py b/detect_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..39cf2ef6f4f99c1f366e58e8222f27c3b88bf8fb --- /dev/null +++ b/detect_layout.py @@ -0,0 +1,67 @@ +import pypdfium2 # Causes a warning if not the top import +import argparse +import copy +import json +from collections import defaultdict + +from surya.detection import batch_text_detection +from surya.input.load import load_from_folder, load_from_file +from surya.layout import batch_layout_detection +from surya.model.detection.model import load_model, load_processor +from surya.postprocessing.heatmap import draw_polys_on_image +from surya.settings import settings +import os + + +def main(): + parser = argparse.ArgumentParser(description="Detect layout of an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect layout in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) + parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) + args = parser.parse_args() + + model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + det_model = load_model() + det_processor = load_processor() + + if os.path.isdir(args.input_path): + images, names, _ = load_from_folder(args.input_path, args.max) + folder_name = os.path.basename(args.input_path) + else: + images, names, _ = load_from_file(args.input_path, args.max) + folder_name = os.path.basename(args.input_path).split(".")[0] + + line_predictions = batch_text_detection(images, det_model, det_processor) + + layout_predictions = batch_layout_detection(images, model, processor, line_predictions) + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + if args.images: + for idx, (image, layout_pred, name) in enumerate(zip(images, layout_predictions, names)): + polygons = [p.polygon for p in layout_pred.bboxes] + labels = [p.label for p in layout_pred.bboxes] + bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image), labels=labels) + bbox_image.save(os.path.join(result_path, f"{name}_{idx}_layout.png")) + + if args.debug: + heatmap = layout_pred.segmentation_map + heatmap.save(os.path.join(result_path, f"{name}_{idx}_segmentation.png")) + + predictions_by_page = defaultdict(list) + for idx, (pred, name, image) in enumerate(zip(layout_predictions, names, images)): + out_pred = pred.model_dump(exclude=["segmentation_map"]) + out_pred["page"] = len(predictions_by_page[name]) + 1 + predictions_by_page[name].append(out_pred) + + with open(os.path.join(result_path, 
"results.json"), "w+", encoding="utf-8") as f: + json.dump(predictions_by_page, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() diff --git a/detect_text.py b/detect_text.py new file mode 100644 index 0000000000000000000000000000000000000000..e7b0e5a6cec17485f1ecb7a60691abe98b38f973 --- /dev/null +++ b/detect_text.py @@ -0,0 +1,81 @@ +import argparse +import copy +import json +import time +from collections import defaultdict + +from surya.input.load import load_from_folder, load_from_file +from surya.model.detection.model import load_model, load_processor +from surya.detection import batch_text_detection +from surya.postprocessing.affinity import draw_lines_on_image +from surya.postprocessing.heatmap import draw_polys_on_image +from surya.settings import settings +import os +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False) + parser.add_argument("--debug", action="store_true", help="Run in debug mode.", default=False) + args = parser.parse_args() + + checkpoint = settings.DETECTOR_MODEL_CHECKPOINT + model = load_model(checkpoint=checkpoint) + processor = load_processor(checkpoint=checkpoint) + + if os.path.isdir(args.input_path): + images, names, _ = load_from_folder(args.input_path, args.max) + folder_name = os.path.basename(args.input_path) + else: + images, names, _ = load_from_file(args.input_path, args.max) + folder_name = os.path.basename(args.input_path).split(".")[0] + + start = time.time() + predictions = batch_text_detection(images, model, processor) + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + end = time.time() + if args.debug: + print(f"Detection took {end - start} seconds") + + if args.images: + for idx, (image, pred, name) in enumerate(zip(images, predictions, names)): + polygons = [p.polygon for p in pred.bboxes] + bbox_image = draw_polys_on_image(polygons, copy.deepcopy(image)) + bbox_image.save(os.path.join(result_path, f"{name}_{idx}_bbox.png")) + + column_image = draw_lines_on_image(pred.vertical_lines, copy.deepcopy(image)) + column_image.save(os.path.join(result_path, f"{name}_{idx}_column.png")) + + if args.debug: + heatmap = pred.heatmap + heatmap.save(os.path.join(result_path, f"{name}_{idx}_heat.png")) + + affinity_map = pred.affinity_map + affinity_map.save(os.path.join(result_path, f"{name}_{idx}_affinity.png")) + + predictions_by_page = defaultdict(list) + for idx, (pred, name, image) in enumerate(zip(predictions, names, images)): + out_pred = pred.model_dump(exclude=["heatmap", "affinity_map"]) + out_pred["page"] = len(predictions_by_page[name]) + 1 + predictions_by_page[name].append(out_pred) + + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(predictions_by_page, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() + + + + + + + diff --git a/ocr_app.py b/ocr_app.py new file mode 100644 
index 0000000000000000000000000000000000000000..cb6791de32054b55205e59b01f054b36877e1306 --- /dev/null +++ b/ocr_app.py @@ -0,0 +1,257 @@ +import io +from typing import List + +import pypdfium2 +import streamlit as st +from pypdfium2 import PdfiumError + +from surya.detection import batch_text_detection +from surya.input.pdflines import get_page_text_lines, get_table_blocks +from surya.layout import batch_layout_detection +from surya.model.detection.model import load_model, load_processor +from surya.model.recognition.model import load_model as load_rec_model +from surya.model.recognition.processor import load_processor as load_rec_processor +from surya.model.ordering.processor import load_processor as load_order_processor +from surya.model.ordering.model import load_model as load_order_model +from surya.model.table_rec.model import load_model as load_table_model +from surya.model.table_rec.processor import load_processor as load_table_processor +from surya.ordering import batch_ordering +from surya.postprocessing.heatmap import draw_polys_on_image, draw_bboxes_on_image +from surya.ocr import run_ocr +from surya.postprocessing.text import draw_text_on_image +from PIL import Image +from surya.languages import CODE_TO_LANGUAGE +from surya.input.langs import replace_lang_with_code +from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult, TableResult +from surya.settings import settings +from surya.tables import batch_table_recognition +from surya.postprocessing.util import rescale_bboxes, rescale_bbox + + +@st.cache_resource() +def load_det_cached(): + checkpoint = settings.DETECTOR_MODEL_CHECKPOINT + return load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint) + + +@st.cache_resource() +def load_rec_cached(): + return load_rec_model(), load_rec_processor() + + +@st.cache_resource() +def load_layout_cached(): + return load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT), load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + +@st.cache_resource() +def load_order_cached(): + return load_order_model(), load_order_processor() + + +@st.cache_resource() +def load_table_cached(): + return load_table_model(), load_table_processor() + + +def text_detection(img) -> (Image.Image, TextDetectionResult): + pred = batch_text_detection([img], det_model, det_processor)[0] + polygons = [p.polygon for p in pred.bboxes] + det_img = draw_polys_on_image(polygons, img.copy()) + return det_img, pred + + +def layout_detection(img) -> (Image.Image, LayoutResult): + _, det_pred = text_detection(img) + pred = batch_layout_detection([img], layout_model, layout_processor, [det_pred])[0] + polygons = [p.polygon for p in pred.bboxes] + labels = [p.label for p in pred.bboxes] + layout_img = draw_polys_on_image(polygons, img.copy(), labels=labels, label_font_size=18) + return layout_img, pred + + +def order_detection(img) -> (Image.Image, OrderResult): + _, layout_pred = layout_detection(img) + bboxes = [l.bbox for l in layout_pred.bboxes] + pred = batch_ordering([img], [bboxes], order_model, order_processor)[0] + polys = [l.polygon for l in pred.bboxes] + positions = [str(l.position) for l in pred.bboxes] + order_img = draw_polys_on_image(polys, img.copy(), labels=positions, label_font_size=18) + return order_img, pred + + +def table_recognition(img, highres_img, filepath, page_idx: int, use_pdf_boxes: bool, skip_table_detection: bool) -> (Image.Image, List[TableResult]): + if skip_table_detection: + layout_tables = [(0, 0, highres_img.size[0], highres_img.size[1])] + 
table_imgs = [highres_img] + else: + _, layout_pred = layout_detection(img) + layout_tables_lowres = [l.bbox for l in layout_pred.bboxes if l.label == "Table"] + table_imgs = [] + layout_tables = [] + for tb in layout_tables_lowres: + highres_bbox = rescale_bbox(tb, img.size, highres_img.size) + table_imgs.append( + highres_img.crop(highres_bbox) + ) + layout_tables.append(highres_bbox) + + try: + page_text = get_page_text_lines(filepath, [page_idx], [highres_img.size])[0] + table_bboxes = get_table_blocks(layout_tables, page_text, highres_img.size) + except PdfiumError: + # This happens when we try to get text from an image + table_bboxes = [[] for _ in layout_tables] + + if not use_pdf_boxes or any(len(tb) == 0 for tb in table_bboxes): + det_results = batch_text_detection(table_imgs, det_model, det_processor) + table_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results] + + table_preds = batch_table_recognition(table_imgs, table_bboxes, table_model, table_processor) + table_img = img.copy() + + for results, table_bbox in zip(table_preds, layout_tables): + adjusted_bboxes = [] + labels = [] + + for item in results.cells: + adjusted_bboxes.append([ + (item.bbox[0] + table_bbox[0]), + (item.bbox[1] + table_bbox[1]), + (item.bbox[2] + table_bbox[0]), + (item.bbox[3] + table_bbox[1]) + ]) + labels.append(f"{item.row_id} / {item.col_id}") + table_img = draw_bboxes_on_image(adjusted_bboxes, highres_img, labels=labels, label_font_size=18) + return table_img, table_preds + + +# Function for OCR +def ocr(img, highres_img, langs: List[str]) -> (Image.Image, OCRResult): + replace_lang_with_code(langs) + img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor, highres_images=[highres_img])[0] + + bboxes = [l.bbox for l in img_pred.text_lines] + text = [l.text for l in img_pred.text_lines] + rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs) + return rec_img, img_pred + + +def open_pdf(pdf_file): + stream = io.BytesIO(pdf_file.getvalue()) + return pypdfium2.PdfDocument(stream) + + +@st.cache_data() +def get_page_image(pdf_file, page_num, dpi=settings.IMAGE_DPI): + doc = open_pdf(pdf_file) + renderer = doc.render( + pypdfium2.PdfBitmap.to_pil, + page_indices=[page_num - 1], + scale=dpi / 72, + ) + png = list(renderer)[0] + png_image = png.convert("RGB") + return png_image + + +@st.cache_data() +def page_count(pdf_file): + doc = open_pdf(pdf_file) + return len(doc) + + +st.set_page_config(layout="wide") +col1, col2 = st.columns([.5, .5]) + +det_model, det_processor = load_det_cached() +rec_model, rec_processor = load_rec_cached() +layout_model, layout_processor = load_layout_cached() +order_model, order_processor = load_order_cached() +table_model, table_processor = load_table_cached() + + +st.markdown(""" +# Surya OCR Demo + +This app will let you try surya, a multilingual OCR model. It supports text detection + layout analysis in any language, and text recognition in 90+ languages. + +Notes: +- This works best on documents with printed text. +- Preprocessing the image (e.g. increasing contrast) can improve results. +- If OCR doesn't work, try changing the resolution of your image (increase if below 2048px width, otherwise decrease). +- This supports 90+ languages, see [here](https://github.com/VikParuchuri/surya/tree/master/surya/languages.py) for a full list. + +Find the project [here](https://github.com/VikParuchuri/surya). 
+""") + +in_file = st.sidebar.file_uploader("PDF file or image:", type=["pdf", "png", "jpg", "jpeg", "gif", "webp"]) +languages = st.sidebar.multiselect("Languages", sorted(list(CODE_TO_LANGUAGE.values())), default=[], max_selections=4, help="Select the languages in the image (if known) to improve OCR accuracy. Optional.") + +if in_file is None: + st.stop() + +filetype = in_file.type +whole_image = False +if "pdf" in filetype: + page_count = page_count(in_file) + page_number = st.sidebar.number_input(f"Page number out of {page_count}:", min_value=1, value=1, max_value=page_count) + + pil_image = get_page_image(in_file, page_number, settings.IMAGE_DPI) + pil_image_highres = get_page_image(in_file, page_number, dpi=settings.IMAGE_DPI_HIGHRES) +else: + pil_image = Image.open(in_file).convert("RGB") + pil_image_highres = pil_image + page_number = None + +text_det = st.sidebar.button("Run Text Detection") +text_rec = st.sidebar.button("Run OCR") +layout_det = st.sidebar.button("Run Layout Analysis") +order_det = st.sidebar.button("Run Reading Order") +table_rec = st.sidebar.button("Run Table Rec") +use_pdf_boxes = st.sidebar.checkbox("PDF table boxes", value=True, help="Table recognition only: Use the bounding boxes from the PDF file vs text detection model.") +skip_table_detection = st.sidebar.checkbox("Skip table detection", value=False, help="Table recognition only: Skip table detection and treat the whole image/page as a table.") + +if pil_image is None: + st.stop() + +# Run Text Detection +if text_det: + det_img, pred = text_detection(pil_image) + with col1: + st.image(det_img, caption="Detected Text", use_column_width=True) + st.json(pred.model_dump(exclude=["heatmap", "affinity_map"]), expanded=True) + + +# Run layout +if layout_det: + layout_img, pred = layout_detection(pil_image) + with col1: + st.image(layout_img, caption="Detected Layout", use_column_width=True) + st.json(pred.model_dump(exclude=["segmentation_map"]), expanded=True) + +# Run OCR +if text_rec: + rec_img, pred = ocr(pil_image, pil_image_highres, languages) + with col1: + st.image(rec_img, caption="OCR Result", use_column_width=True) + json_tab, text_tab = st.tabs(["JSON", "Text Lines (for debugging)"]) + with json_tab: + st.json(pred.model_dump(), expanded=True) + with text_tab: + st.text("\n".join([p.text for p in pred.text_lines])) + +if order_det: + order_img, pred = order_detection(pil_image) + with col1: + st.image(order_img, caption="Reading Order", use_column_width=True) + st.json(pred.model_dump(), expanded=True) + + +if table_rec: + table_img, pred = table_recognition(pil_image, pil_image_highres, in_file, page_number - 1 if page_number else None, use_pdf_boxes, skip_table_detection) + with col1: + st.image(table_img, caption="Table Recognition", use_column_width=True) + st.json([p.model_dump() for p in pred], expanded=True) + +with col2: + st.image(pil_image, caption="Uploaded Image", use_column_width=True) \ No newline at end of file diff --git a/ocr_text.py b/ocr_text.py new file mode 100644 index 0000000000000000000000000000000000000000..be24f6f77203ac7e8e39b95db5b44530bde09e77 --- /dev/null +++ b/ocr_text.py @@ -0,0 +1,98 @@ +import os +import argparse +import json +import time +from collections import defaultdict + +import torch + +from surya.input.langs import replace_lang_with_code, get_unique_langs +from surya.input.load import load_from_folder, load_from_file, load_lang_file +from surya.model.detection.model import load_model as load_detection_model, load_processor as load_detection_processor +from 
surya.model.recognition.model import load_model as load_recognition_model +from surya.model.recognition.processor import load_processor as load_recognition_processor +from surya.model.recognition.tokenizer import _tokenize +from surya.ocr import run_ocr +from surya.postprocessing.text import draw_text_on_image +from surya.settings import settings + + +def main(): + parser = argparse.ArgumentParser(description="Detect bboxes in an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to detect bboxes in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + parser.add_argument("--start_page", type=int, help="Page to start processing at.", default=0) + parser.add_argument("--images", action="store_true", help="Save images of detected bboxes.", default=False) + parser.add_argument("--langs", type=str, help="Optional language(s) to use for OCR. Comma separate for multiple. Can be a capitalized language name, or a 2-letter ISO 639 code.", default=None) + parser.add_argument("--lang_file", type=str, help="Optional path to file with languages to use for OCR. Should be a JSON dict with file names as keys, and the value being a list of language codes/names.", default=None) + parser.add_argument("--debug", action="store_true", help="Enable debug logging.", default=False) + args = parser.parse_args() + + if os.path.isdir(args.input_path): + images, names, _ = load_from_folder(args.input_path, args.max, args.start_page) + highres_images, _, _ = load_from_folder(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES) + folder_name = os.path.basename(args.input_path) + else: + images, names, _ = load_from_file(args.input_path, args.max, args.start_page) + highres_images, _, _ = load_from_file(args.input_path, args.max, args.start_page, settings.IMAGE_DPI_HIGHRES) + folder_name = os.path.basename(args.input_path).split(".")[0] + + if args.lang_file: + # We got all of our language settings from a file + langs = load_lang_file(args.lang_file, names) + for lang in langs: + replace_lang_with_code(lang) + image_langs = langs + elif args.langs: + # We got our language settings from the input + langs = args.langs.split(",") + replace_lang_with_code(langs) + image_langs = [langs] * len(images) + else: + image_langs = [None] * len(images) + + det_processor = load_detection_processor() + det_model = load_detection_model() + + rec_model = load_recognition_model() + rec_processor = load_recognition_processor() + + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + start = time.time() + predictions_by_image = run_ocr(images, image_langs, det_model, det_processor, rec_model, rec_processor, highres_images=highres_images) + if args.debug: + print(f"OCR took {time.time() - start:.2f} seconds") + max_chars = max([len(l.text) for p in predictions_by_image for l in p.text_lines]) + print(f"Max chars: {max_chars}") + + if args.images: + for idx, (name, image, pred, langs) in enumerate(zip(names, images, predictions_by_image, image_langs)): + bboxes = [l.bbox for l in pred.text_lines] + pred_text = [l.text for l in pred.text_lines] + page_image = draw_text_on_image(bboxes, pred_text, image.size, langs, has_math="_math" in langs if langs else False) + page_image.save(os.path.join(result_path, 
f"{name}_{idx}_text.png")) + + out_preds = defaultdict(list) + for name, pred, image in zip(names, predictions_by_image, images): + out_pred = pred.model_dump() + out_pred["page"] = len(out_preds[name]) + 1 + out_preds[name].append(out_pred) + + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(out_preds, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() + + + + + + + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..dc5105c2cbc6661099e3f64cfcd635f79e9be416 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[tool.poetry] +name = "surya-ocr" +version = "0.6.1" +description = "OCR, layout, reading order, and table recognition in 90+ languages" +authors = ["Vik Paruchuri "] +readme = "README.md" +license = "GPL-3.0-or-later" +repository = "https://github.com/VikParuchuri/surya" +keywords = ["ocr", "pdf", "text detection", "text recognition", "tables"] +packages = [ + {include = "surya"} +] +include = [ + "detect_text.py", + "ocr_text.py", + "ocr_app.py", + "run_ocr_app.py", + "detect_layout.py", + "reading_order.py", + "table_recognition.py" +] + +[tool.poetry.dependencies] +python = ">=3.9,<3.13,!=3.9.7" +transformers = "^4.41.0" +torch = "^2.3.0" +pydantic = "^2.5.3" +pydantic-settings = "^2.1.0" +python-dotenv = "^1.0.0" +pillow = "^10.2.0" +pypdfium2 = "^4.25.0" +opencv-python = "^4.9.0.80" +tabulate = "^0.9.0" +filetype = "^1.2.0" +ftfy = "^6.1.3" +pdftext = "^0.3.12" + +[tool.poetry.group.dev.dependencies] +jupyter = "^1.0.0" +pytesseract = "^0.3.10" +pymupdf = "^1.23.8" +snakeviz = "^2.2.0" +datasets = "^2.16.1" +rapidfuzz = "^3.6.1" +arabic-reshaper = "^3.0.0" +streamlit = "^1.31.0" +playwright = "^1.41.2" + +[tool.poetry.scripts] +surya_detect = "detect_text:main" +surya_ocr = "ocr_text:main" +surya_layout = "detect_layout:main" +surya_gui = "run_ocr_app:run_app" +surya_order = "reading_order:main" +surya_table = "table_recognition:main" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/reading_order.py b/reading_order.py new file mode 100644 index 0000000000000000000000000000000000000000..4277a8aa2fd44dba274a1227961513fcbdc8cf19 --- /dev/null +++ b/reading_order.py @@ -0,0 +1,81 @@ +import os +import argparse +import copy +import json +from collections import defaultdict + +from surya.detection import batch_text_detection +from surya.input.load import load_from_folder, load_from_file +from surya.layout import batch_layout_detection +from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor +from surya.model.ordering.model import load_model +from surya.model.ordering.processor import load_processor +from surya.ordering import batch_ordering +from surya.postprocessing.heatmap import draw_polys_on_image +from surya.settings import settings + + +def main(): + parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + parser.add_argument("--images", action="store_true", help="Save images of detected layout 
bboxes.", default=False) + args = parser.parse_args() + + model = load_model() + processor = load_processor() + + layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + + det_model = load_det_model() + det_processor = load_det_processor() + + if os.path.isdir(args.input_path): + images, names, _ = load_from_folder(args.input_path, args.max) + folder_name = os.path.basename(args.input_path) + else: + images, names, _ = load_from_file(args.input_path, args.max) + folder_name = os.path.basename(args.input_path).split(".")[0] + + line_predictions = batch_text_detection(images, det_model, det_processor) + layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions) + bboxes = [] + for layout_pred in layout_predictions: + bbox = [l.bbox for l in layout_pred.bboxes] + bboxes.append(bbox) + + order_predictions = batch_ordering(images, bboxes, model, processor) + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + if args.images: + for idx, (image, layout_pred, order_pred, name) in enumerate(zip(images, layout_predictions, order_predictions, names)): + polys = [l.polygon for l in order_pred.bboxes] + labels = [str(l.position) for l in order_pred.bboxes] + bbox_image = draw_polys_on_image(polys, copy.deepcopy(image), labels=labels, label_font_size=20) + bbox_image.save(os.path.join(result_path, f"{name}_{idx}_order.png")) + + predictions_by_page = defaultdict(list) + for idx, (layout_pred, pred, name, image) in enumerate(zip(layout_predictions, order_predictions, names, images)): + out_pred = pred.model_dump() + for bbox, layout_bbox in zip(out_pred["bboxes"], layout_pred.bboxes): + bbox["label"] = layout_bbox.label + + out_pred["page"] = len(predictions_by_page[name]) + 1 + predictions_by_page[name].append(out_pred) + + # Sort in reading order + for name in predictions_by_page: + for page_preds in predictions_by_page[name]: + page_preds["bboxes"] = sorted(page_preds["bboxes"], key=lambda x: x["position"]) + + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(predictions_by_page, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..69b6b58e3feb87793e9c4b12c681aabf035360cf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +streamlit +torch +torchvision +torchaudio +surya-ocr \ No newline at end of file diff --git a/scripts/verify_benchmark_scores.py b/scripts/verify_benchmark_scores.py new file mode 100644 index 0000000000000000000000000000000000000000..956aaae9a04b5d4ba7de709aedc1734d5355192e --- /dev/null +++ b/scripts/verify_benchmark_scores.py @@ -0,0 +1,61 @@ +import json +import argparse + + +def verify_layout(data): + scores = data["metrics"] + for layout_type, metrics in scores.items(): + if metrics["precision"] <= 0.6 or metrics["recall"] <= 0.6: + raise ValueError("Scores do not meet the required threshold") + + +def verify_det(data): + scores = data["metrics"]["surya"] + if scores["precision"] <= 0.9 or scores["recall"] <= 0.9: + raise ValueError("Scores do not meet the required threshold") + + +def verify_rec(data): + scores = data["surya"] + if scores["avg_score"] <= 0.9: + raise ValueError("Scores do not meet the required threshold") + + +def verify_order(data): + score = 
data["mean_accuracy"] + if score < 0.75: + raise ValueError("Scores do not meet the required threshold") + + +def verify_table_rec(data): + row_score = data["surya"]["mean_row_iou"] + col_score = data["surya"]["mean_col_iou"] + + if row_score < 0.75 or col_score < 0.75: + raise ValueError("Scores do not meet the required threshold") + + +def verify_scores(file_path, bench_type): + with open(file_path, 'r') as file: + data = json.load(file) + + if bench_type == "detection": + verify_det(data) + elif bench_type == "recognition": + verify_rec(data) + elif bench_type == "layout": + verify_layout(data) + elif bench_type == "ordering": + verify_order(data) + elif bench_type == "table_recognition": + verify_table_rec(data) + else: + raise ValueError("Invalid benchmark type") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Verify benchmark scores") + parser.add_argument("file_path", type=str, help="Path to the json file") + parser.add_argument("--bench_type", type=str, help="Type of benchmark to verify", default="detection") + args = parser.parse_args() + verify_scores(args.file_path, args.bench_type) diff --git a/surya/benchmark/bbox.py b/surya/benchmark/bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..b7593e836dd4cc3402557d51e1ac55773da06849 --- /dev/null +++ b/surya/benchmark/bbox.py @@ -0,0 +1,22 @@ +import fitz as pymupdf +from surya.postprocessing.util import rescale_bbox + + +def get_pdf_lines(pdf_path, img_sizes): + doc = pymupdf.open(pdf_path) + page_lines = [] + for idx, img_size in enumerate(img_sizes): + page = doc[idx] + blocks = page.get_text("dict", sort=True, flags=pymupdf.TEXTFLAGS_DICT & ~pymupdf.TEXT_PRESERVE_LIGATURES & ~pymupdf.TEXT_PRESERVE_IMAGES)["blocks"] + + line_boxes = [] + for block_idx, block in enumerate(blocks): + for l in block["lines"]: + line_boxes.append(list(l["bbox"])) + + page_box = page.bound() + pwidth, pheight = page_box[2] - page_box[0], page_box[3] - page_box[1] + line_boxes = [rescale_bbox(bbox, (pwidth, pheight), img_size) for bbox in line_boxes] + page_lines.append(line_boxes) + + return page_lines \ No newline at end of file diff --git a/surya/benchmark/metrics.py b/surya/benchmark/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..15827690e3bb5322624fcf157383ab584a23af5f --- /dev/null +++ b/surya/benchmark/metrics.py @@ -0,0 +1,193 @@ +from functools import partial +from itertools import repeat + +import numpy as np +from concurrent.futures import ProcessPoolExecutor + + +def intersection_area(box1, box2): + x_left = max(box1[0], box2[0]) + y_top = max(box1[1], box2[1]) + x_right = min(box1[2], box2[2]) + y_bottom = min(box1[3], box2[3]) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + return (x_right - x_left) * (y_bottom - y_top) + +def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + +def calculate_iou(box1, box2, box1_only=False): + intersection = intersection_area(box1, box2) + union = box_area(box1) + if not box1_only: + union += box_area(box2) - intersection + + if union == 0: + return 0 + return intersection / union + + +def match_boxes(preds, references): + num_actual = len(references) + num_predicted = len(preds) + + iou_matrix = np.zeros((num_actual, num_predicted)) + for i, actual in enumerate(references): + for j, pred in enumerate(preds): + iou_matrix[i, j] = calculate_iou(actual, pred, box1_only=True) + + sorted_indices = np.argsort(iou_matrix, axis=None)[::-1] + sorted_ious = iou_matrix.flatten()[sorted_indices] + 
actual_indices, predicted_indices = np.unravel_index(sorted_indices, iou_matrix.shape) + + assigned_actual = set() + assigned_pred = set() + + matches = [] + for idx, iou in zip(zip(actual_indices, predicted_indices), sorted_ious): + i, j = idx + if i not in assigned_actual and j not in assigned_pred: + iou_val = iou_matrix[i, j] + if iou_val > .95: # Account for rounding on box edges + iou_val = 1.0 + matches.append((i, j, iou_val)) + assigned_actual.add(i) + assigned_pred.add(j) + + unassigned_actual = set(range(num_actual)) - assigned_actual + unassigned_pred = set(range(num_predicted)) - assigned_pred + matches.extend([(i, None, -1.0) for i in unassigned_actual]) + matches.extend([(None, j, 0.0) for j in unassigned_pred]) + + return matches + +def penalized_iou_score(preds, references): + matches = match_boxes(preds, references) + iou = sum([match[2] for match in matches]) / len(matches) + return iou + +def intersection_pixels(box1, box2): + x_left = max(box1[0], box2[0]) + y_top = max(box1[1], box2[1]) + x_right = min(box1[2], box2[2]) + y_bottom = min(box1[3], box2[3]) + + if x_right < x_left or y_bottom < y_top: + return set() + + x_left, x_right = int(x_left), int(x_right) + y_top, y_bottom = int(y_top), int(y_bottom) + + coords = np.meshgrid(np.arange(x_left, x_right), np.arange(y_top, y_bottom)) + pixels = set(zip(coords[0].flat, coords[1].flat)) + + return pixels + + +def calculate_coverage(box, other_boxes, penalize_double=False): + box_area = (box[2] - box[0]) * (box[3] - box[1]) + if box_area == 0: + return 0 + + # find total coverage of the box + covered_pixels = set() + double_coverage = list() + for other_box in other_boxes: + ia = intersection_pixels(box, other_box) + double_coverage.append(list(covered_pixels.intersection(ia))) + covered_pixels = covered_pixels.union(ia) + + # Penalize double coverage - having multiple bboxes overlapping the same pixels + double_coverage_penalty = len(double_coverage) + if not penalize_double: + double_coverage_penalty = 0 + covered_pixels_count = max(0, len(covered_pixels) - double_coverage_penalty) + return covered_pixels_count / box_area + + +def calculate_coverage_fast(box, other_boxes, penalize_double=False): + box_area = (box[2] - box[0]) * (box[3] - box[1]) + if box_area == 0: + return 0 + + total_intersect = 0 + for other_box in other_boxes: + total_intersect += intersection_area(box, other_box) + + return min(1, total_intersect / box_area) + + +def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True): + if len(references) == 0: + return { + "precision": 1, + "recall": 1, + } + + if len(preds) == 0: + return { + "precision": 0, + "recall": 0, + } + + # If we're not penalizing double coverage, we can use a faster calculation + coverage_func = calculate_coverage_fast + if penalize_double: + coverage_func = calculate_coverage + + with ProcessPoolExecutor(max_workers=workers) as executor: + precision_func = partial(coverage_func, penalize_double=penalize_double) + precision_iou = executor.map(precision_func, preds, repeat(references)) + reference_iou = executor.map(coverage_func, references, repeat(preds)) + + precision_classes = [1 if i > threshold else 0 for i in precision_iou] + precision = sum(precision_classes) / len(precision_classes) + + recall_classes = [1 if i > threshold else 0 for i in reference_iou] + recall = sum(recall_classes) / len(recall_classes) + + return { + "precision": precision, + "recall": recall, + } + + +def mean_coverage(preds, references): + coverages = [] + + for box1 in 
references: + coverage = calculate_coverage(box1, preds) + coverages.append(coverage) + + for box2 in preds: + coverage = calculate_coverage(box2, references) + coverages.append(coverage) + + # Calculate the average coverage over all comparisons + if len(coverages) == 0: + return 0 + coverage = sum(coverages) / len(coverages) + return {"coverage": coverage} + + +def rank_accuracy(preds, references): + # Preds and references need to be aligned so each position refers to the same bbox + pairs = [] + for i, pred in enumerate(preds): + for j, pred2 in enumerate(preds): + if i == j: + continue + pairs.append((i, j, pred > pred2)) + + # Find how many of the prediction rankings are correct + correct = 0 + for i, ref in enumerate(references): + for j, ref2 in enumerate(references): + if (i, j, ref > ref2) in pairs: + correct += 1 + + return correct / len(pairs) \ No newline at end of file diff --git a/surya/benchmark/tatr.py b/surya/benchmark/tatr.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9d9f65ad7921708d58c3df81a98ec5fec95164 --- /dev/null +++ b/surya/benchmark/tatr.py @@ -0,0 +1,117 @@ +import torch +from transformers import DetrFeatureExtractor, AutoModelForObjectDetection +from surya.settings import settings + +from PIL import Image +import numpy as np + + +class MaxResize(object): + def __init__(self, max_size=800): + self.max_size = max_size + + def __call__(self, image): + width, height = image.size + current_max_size = max(width, height) + scale = self.max_size / current_max_size + resized_image = image.resize((int(round(scale * width)), int(round(scale * height)))) + + return resized_image + + +def to_tensor(image): + # Convert PIL Image to NumPy array + np_image = np.array(image).astype(np.float32) + + # Rearrange dimensions to [C, H, W] format + np_image = np_image.transpose((2, 0, 1)) + + # Normalize to [0.0, 1.0] + np_image /= 255.0 + + return torch.from_numpy(np_image) + + +def normalize(tensor, mean, std): + for t, m, s in zip(tensor, mean, std): + t.sub_(m).div_(s) + return tensor + + +def structure_transform(image): + image = MaxResize(1000)(image) + tensor = to_tensor(image) + normalized_tensor = normalize(tensor, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + return normalized_tensor + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=1) + + +def rescale_bboxes(out_bbox, size): + width, height = size + boxes = box_cxcywh_to_xyxy(out_bbox) + boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32) + return boxes + + +def outputs_to_objects(outputs, img_sizes, id2label): + m = outputs.logits.softmax(-1).max(-1) + batch_labels = list(m.indices.detach().cpu().numpy()) + batch_scores = list(m.values.detach().cpu().numpy()) + batch_bboxes = outputs['pred_boxes'].detach().cpu() + + batch_objects = [] + for i in range(len(img_sizes)): + pred_bboxes = [elem.tolist() for elem in rescale_bboxes(batch_bboxes[i], img_sizes[i])] + pred_scores = batch_scores[i] + pred_labels = batch_labels[i] + + objects = [] + for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): + class_label = id2label[int(label)] + if not class_label == 'no object': + objects.append({ + 'label': class_label, + 'score': float(score), + 'bbox': [float(elem) for elem in bbox]} + ) + + rows = [] + cols = [] + for i, cell in enumerate(objects): + if cell["label"] == "table column": + cols.append(cell) + + if cell["label"] == "table row": + 
rows.append(cell) + batch_objects.append({ + "rows": rows, + "cols": cols + }) + + return batch_objects + + +def load_tatr(): + return AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all").to(settings.TORCH_DEVICE_MODEL) + + +def batch_inference_tatr(model, images, batch_size): + device = model.device + rows_cols = [] + for i in range(0, len(images), batch_size): + batch_images = images[i:i + batch_size] + pixel_values = torch.stack([structure_transform(img) for img in batch_images], dim=0).to(device) + + # forward pass + with torch.no_grad(): + outputs = model(pixel_values) + + id2label = model.config.id2label + id2label[len(model.config.id2label)] = "no object" + rows_cols.extend(outputs_to_objects(outputs, [img.size for img in batch_images], id2label)) + return rows_cols \ No newline at end of file diff --git a/surya/benchmark/tesseract.py b/surya/benchmark/tesseract.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d025e0f01fc9e1a3907817f1fcc70461fa42e2 --- /dev/null +++ b/surya/benchmark/tesseract.py @@ -0,0 +1,179 @@ +from typing import List, Optional + +import numpy as np +import pytesseract +from pytesseract import Output +from tqdm import tqdm + +from surya.input.processing import slice_bboxes_from_image +from surya.settings import settings +import os +from concurrent.futures import ProcessPoolExecutor +from surya.detection import get_batch_size as get_det_batch_size +from surya.recognition import get_batch_size as get_rec_batch_size +from surya.languages import CODE_TO_LANGUAGE + + +def surya_lang_to_tesseract(code: str) -> Optional[str]: + lang_str = CODE_TO_LANGUAGE[code] + try: + tess_lang = TESS_LANGUAGE_TO_CODE[lang_str] + except KeyError: + return None + return tess_lang + + +def tesseract_ocr(img, bboxes, lang: str): + line_imgs = slice_bboxes_from_image(img, bboxes) + config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"' + lines = [] + for line_img in line_imgs: + line = pytesseract.image_to_string(line_img, lang=lang, config=config) + lines.append(line) + return lines + + +def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None): + tess_parallel_cores = min(len(imgs), get_rec_batch_size()) + if not cpus: + cpus = os.cpu_count() + tess_parallel_cores = min(tess_parallel_cores, cpus) + + # Tesseract uses up to 4 processes per instance + # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images + tess_parallel = max(tess_parallel_cores // 2, 1) + + with ProcessPoolExecutor(max_workers=tess_parallel) as executor: + tess_text = tqdm(executor.map(tesseract_ocr, imgs, bboxes, langs), total=len(imgs), desc="Running tesseract OCR") + tess_text = list(tess_text) + return tess_text + + +def tesseract_bboxes(img): + arr_img = np.asarray(img, dtype=np.uint8) + ocr = pytesseract.image_to_data(arr_img, output_type=Output.DICT) + + bboxes = [] + n_boxes = len(ocr['level']) + for i in range(n_boxes): + # It is possible to merge by line here with line number, but it gives bad results. 
+ _, x, y, w, h = ocr['text'][i], ocr['left'][i], ocr['top'][i], ocr['width'][i], ocr['height'][i] + bbox = (x, y, x + w, y + h) + bboxes.append(bbox) + + return bboxes + + +def tesseract_parallel(imgs): + # Tesseract uses 4 threads per instance + tess_parallel_cores = min(len(imgs), get_det_batch_size()) + cpus = os.cpu_count() + tess_parallel_cores = min(tess_parallel_cores, cpus) + + # Tesseract uses 4 threads per instance + tess_parallel = max(tess_parallel_cores // 4, 1) + + with ProcessPoolExecutor(max_workers=tess_parallel) as executor: + tess_bboxes = tqdm(executor.map(tesseract_bboxes, imgs), total=len(imgs), desc="Running tesseract bbox detection") + tess_bboxes = list(tess_bboxes) + return tess_bboxes + + +TESS_CODE_TO_LANGUAGE = { + "afr": "Afrikaans", + "amh": "Amharic", + "ara": "Arabic", + "asm": "Assamese", + "aze": "Azerbaijani", + "bel": "Belarusian", + "ben": "Bengali", + "bod": "Tibetan", + "bos": "Bosnian", + "bre": "Breton", + "bul": "Bulgarian", + "cat": "Catalan", + "ceb": "Cebuano", + "ces": "Czech", + "chi_sim": "Chinese", + "chr": "Cherokee", + "cym": "Welsh", + "dan": "Danish", + "deu": "German", + "dzo": "Dzongkha", + "ell": "Greek", + "eng": "English", + "epo": "Esperanto", + "est": "Estonian", + "eus": "Basque", + "fas": "Persian", + "fin": "Finnish", + "fra": "French", + "fry": "Western Frisian", + "guj": "Gujarati", + "gla": "Scottish Gaelic", + "gle": "Irish", + "glg": "Galician", + "heb": "Hebrew", + "hin": "Hindi", + "hrv": "Croatian", + "hun": "Hungarian", + "hye": "Armenian", + "iku": "Inuktitut", + "ind": "Indonesian", + "isl": "Icelandic", + "ita": "Italian", + "jav": "Javanese", + "jpn": "Japanese", + "kan": "Kannada", + "kat": "Georgian", + "kaz": "Kazakh", + "khm": "Khmer", + "kir": "Kyrgyz", + "kor": "Korean", + "lao": "Lao", + "lat": "Latin", + "lav": "Latvian", + "lit": "Lithuanian", + "mal": "Malayalam", + "mar": "Marathi", + "mkd": "Macedonian", + "mlt": "Maltese", + "mon": "Mongolian", + "msa": "Malay", + "mya": "Burmese", + "nep": "Nepali", + "nld": "Dutch", + "nor": "Norwegian", + "ori": "Oriya", + "pan": "Punjabi", + "pol": "Polish", + "por": "Portuguese", + "pus": "Pashto", + "ron": "Romanian", + "rus": "Russian", + "san": "Sanskrit", + "sin": "Sinhala", + "slk": "Slovak", + "slv": "Slovenian", + "snd": "Sindhi", + "spa": "Spanish", + "sqi": "Albanian", + "srp": "Serbian", + "swa": "Swahili", + "swe": "Swedish", + "syr": "Syriac", + "tam": "Tamil", + "tel": "Telugu", + "tgk": "Tajik", + "tha": "Thai", + "tir": "Tigrinya", + "tur": "Turkish", + "uig": "Uyghur", + "ukr": "Ukrainian", + "urd": "Urdu", + "uzb": "Uzbek", + "vie": "Vietnamese", + "yid": "Yiddish" +} + +TESS_LANGUAGE_TO_CODE = {v:k for k,v in TESS_CODE_TO_LANGUAGE.items()} diff --git a/surya/benchmark/util.py b/surya/benchmark/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a32f470390845e4feffce6adcc58aa763749c29d --- /dev/null +++ b/surya/benchmark/util.py @@ -0,0 +1,31 @@ +def merge_boxes(box1, box2): + return (min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])) + + +def join_lines(bboxes, max_gap=5): + to_merge = {} + for i, box1 in bboxes: + for z, box2 in bboxes[i + 1:]: + j = i + z + 1 + if box1 == box2: + continue + + if box1[0] <= box2[0] and box1[2] >= box2[2]: + if abs(box1[1] - box2[3]) <= max_gap: + if i not in to_merge: + to_merge[i] = [] + to_merge[i].append(j) + + merged_boxes = set() + merged = [] + for i, box in bboxes: + if i in merged_boxes: + continue + + if i in to_merge: + for j in 
to_merge[i]: + box = merge_boxes(box, bboxes[j][1]) + merged_boxes.add(j) + + merged.append(box) + return merged diff --git a/surya/detection.py b/surya/detection.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b4a282be126bdb356c8f1c3adcba7e3f747cb3 --- /dev/null +++ b/surya/detection.py @@ -0,0 +1,144 @@ +from typing import List, Tuple, Generator + +import torch +import numpy as np +from PIL import Image + +from surya.model.detection.model import EfficientViTForSemanticSegmentation +from surya.postprocessing.heatmap import get_and_clean_boxes +from surya.postprocessing.affinity import get_vertical_lines +from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb +from surya.schema import TextDetectionResult +from surya.settings import settings +from tqdm import tqdm +from concurrent.futures import ProcessPoolExecutor +import torch.nn.functional as F + + +def get_batch_size(): + batch_size = settings.DETECTOR_BATCH_SIZE + if batch_size is None: + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 36 + return batch_size + + +def batch_detection( + images: List, + model: EfficientViTForSemanticSegmentation, + processor, + batch_size=None +) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]: + assert all([isinstance(image, Image.Image) for image in images]) + if batch_size is None: + batch_size = get_batch_size() + heatmap_count = model.config.num_labels + + orig_sizes = [image.size for image in images] + splits_per_image = [get_total_splits(size, processor) for size in orig_sizes] + + batches = [] + current_batch_size = 0 + current_batch = [] + for i in range(len(images)): + if current_batch_size + splits_per_image[i] > batch_size: + if len(current_batch) > 0: + batches.append(current_batch) + current_batch = [] + current_batch_size = 0 + current_batch.append(i) + current_batch_size += splits_per_image[i] + + if len(current_batch) > 0: + batches.append(current_batch) + + for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"): + batch_image_idxs = batches[batch_idx] + batch_images = [images[j].convert("RGB") for j in batch_image_idxs] + + split_index = [] + split_heights = [] + image_splits = [] + for image_idx, image in enumerate(batch_images): + image_parts, split_height = split_image(image, processor) + image_splits.extend(image_parts) + split_index.extend([image_idx] * len(image_parts)) + split_heights.extend(split_height) + + image_splits = [prepare_image_detection(image, processor) for image in image_splits] + # Batch images in dim 0 + batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device) + + with torch.inference_mode(): + pred = model(pixel_values=batch) + + logits = pred.logits + correct_shape = [processor.size["height"], processor.size["width"]] + current_shape = list(logits.shape[2:]) + if current_shape != correct_shape: + logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False) + + logits = logits.cpu().detach().numpy().astype(np.float32) + preds = [] + for i, (idx, height) in enumerate(zip(split_index, split_heights)): + # If our current prediction length is below the image idx, that means we have a new image + # Otherwise, we need to add to the current image + if len(preds) <= idx: + preds.append([logits[i][k] for k in range(heatmap_count)]) + else: + heatmaps = preds[idx] + pred_heatmaps = [logits[i][k] for k in 
range(heatmap_count)] + + if height < processor.size["height"]: + # Cut off padding to get original height + pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps] + + for k in range(heatmap_count): + heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]]) + preds[idx] = heatmaps + + yield preds, [orig_sizes[j] for j in batch_image_idxs] + + +def parallel_get_lines(preds, orig_sizes): + heatmap, affinity_map = preds + heat_img = Image.fromarray((heatmap * 255).astype(np.uint8)) + aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8)) + affinity_size = list(reversed(affinity_map.shape)) + heatmap_size = list(reversed(heatmap.shape)) + bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes) + vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes) + + result = TextDetectionResult( + bboxes=bboxes, + vertical_lines=vertical_lines, + heatmap=heat_img, + affinity_map=aff_img, + image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]] + ) + return result + + +def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]: + detection_generator = batch_detection(images, model, processor, batch_size=batch_size) + + results = [] + max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) + parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH + + if parallelize: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + for preds, orig_sizes in detection_generator: + batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes)) + results.extend(batch_results) + else: + for preds, orig_sizes in detection_generator: + for pred, orig_size in zip(preds, orig_sizes): + results.append(parallel_get_lines(pred, orig_size)) + + return results + + diff --git a/surya/input/langs.py b/surya/input/langs.py new file mode 100644 index 0000000000000000000000000000000000000000..e347408ff7c6adc9ca62a48ba10c60057b70b9b7 --- /dev/null +++ b/surya/input/langs.py @@ -0,0 +1,19 @@ +from typing import List +from surya.languages import LANGUAGE_TO_CODE, CODE_TO_LANGUAGE + + +def replace_lang_with_code(langs: List[str]): + for i in range(len(langs)): + if langs[i].title() in LANGUAGE_TO_CODE: + langs[i] = LANGUAGE_TO_CODE[langs[i].title()] + if langs[i] not in CODE_TO_LANGUAGE: + raise ValueError(f"Language code {langs[i]} not found.") + + +def get_unique_langs(langs: List[List[str]]): + uniques = [] + for lang_list in langs: + for lang in lang_list: + if lang not in uniques: + uniques.append(lang) + return uniques \ No newline at end of file diff --git a/surya/input/load.py b/surya/input/load.py new file mode 100644 index 0000000000000000000000000000000000000000..87a71450c1dbc9e92f7157e6c74b775ba0346d0b --- /dev/null +++ b/surya/input/load.py @@ -0,0 +1,87 @@ +import PIL + +from surya.input.processing import open_pdf, get_page_images +from surya.settings import settings +import os +import filetype +from PIL import Image +import json + + + +def get_name_from_path(path): + return os.path.basename(path).split(".")[0] + + +def load_pdf(pdf_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False): + doc = open_pdf(pdf_path) + last_page = len(doc) + + if start_page: + assert start_page < last_page and start_page >= 0, f"Start page must be between 0 and {last_page}" + else: + start_page = 0 + + if max_pages: + assert max_pages >= 0, f"Max pages must be greater than 0" + last_page = min(start_page + max_pages, last_page) + + 
page_indices = list(range(start_page, last_page)) + images = get_page_images(doc, page_indices, dpi=dpi) + text_lines = None + if load_text_lines: + from surya.input.pdflines import get_page_text_lines # Putting import here because pypdfium2 causes warnings if its not the top import + text_lines = get_page_text_lines( + pdf_path, + page_indices, + [i.size for i in images] + ) + doc.close() + names = [get_name_from_path(pdf_path) for _ in page_indices] + return images, names, text_lines + + +def load_image(image_path): + image = Image.open(image_path).convert("RGB") + name = get_name_from_path(image_path) + return [image], [name], [None] + + +def load_from_file(input_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False): + input_type = filetype.guess(input_path) + if input_type.extension == "pdf": + return load_pdf(input_path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines) + else: + return load_image(input_path) + + +def load_from_folder(folder_path, max_pages=None, start_page=None, dpi=settings.IMAGE_DPI, load_text_lines=False): + image_paths = [os.path.join(folder_path, image_name) for image_name in os.listdir(folder_path) if not image_name.startswith(".")] + image_paths = [ip for ip in image_paths if not os.path.isdir(ip)] + + images = [] + names = [] + text_lines = [] + for path in image_paths: + extension = filetype.guess(path) + if extension and extension.extension == "pdf": + image, name, text_line = load_pdf(path, max_pages, start_page, dpi=dpi, load_text_lines=load_text_lines) + images.extend(image) + names.extend(name) + text_lines.extend(text_line) + else: + try: + image, name, text_line = load_image(path) + images.extend(image) + names.extend(name) + text_lines.extend(text_line) + except PIL.UnidentifiedImageError: + print(f"Could not load image {path}") + continue + return images, names, text_lines + + +def load_lang_file(lang_path, names): + with open(lang_path, "r") as f: + lang_dict = json.load(f) + return [lang_dict[name].copy() for name in names] diff --git a/surya/input/pdflines.py b/surya/input/pdflines.py new file mode 100644 index 0000000000000000000000000000000000000000..8f36aefa171aeedee2578380954ccfb0e98f0090 --- /dev/null +++ b/surya/input/pdflines.py @@ -0,0 +1,86 @@ +from pdftext.extraction import dictionary_output + +from surya.postprocessing.text import sort_text_lines +from surya.schema import PolygonBox + + +def get_page_text_lines(filepath: str, page_idxs: list, out_sizes: list) -> list: + assert len(page_idxs) == len(out_sizes) + pages_text = dictionary_output(filepath, sort=False, page_range=page_idxs, keep_chars=True) + for full_text, out_size in zip(pages_text, out_sizes): + width = full_text["width"] + height = full_text["height"] + text_w_scale = out_size[0] / width + text_h_scale = out_size[1] / height + for block in full_text["blocks"]: + for line in block["lines"]: + line["bbox"] = [line["bbox"][0] * text_w_scale, line["bbox"][1] * text_h_scale, + line["bbox"][2] * text_w_scale, line["bbox"][3] * text_h_scale] + for span in line["spans"]: + for char in span["chars"]: + char["bbox"] = [char["bbox"][0] * text_w_scale, char["bbox"][1] * text_h_scale, + char["bbox"][2] * text_w_scale, char["bbox"][3] * text_h_scale] + return pages_text + + +def get_table_blocks(tables: list, full_text: dict, img_size: list, table_thresh=.8): + # Returns coordinates relative to input table, not full image + table_texts = [] + for table in tables: + table_poly = PolygonBox(polygon=[ + [table[0], table[1]], + [table[2], 
table[1]], + [table[2], table[3]], + [table[0], table[3]] + ]) + table_text = [] + rotation = full_text["rotation"] + for block in full_text["blocks"]: + for line in block["lines"]: + line_poly = PolygonBox(polygon=[ + [line["bbox"][0], line["bbox"][1]], + [line["bbox"][2], line["bbox"][1]], + [line["bbox"][2], line["bbox"][3]], + [line["bbox"][0], line["bbox"][3]] + ]) + if line_poly.intersection_pct(table_poly) < table_thresh: + continue + curr_span = None + curr_box = None + for span in line["spans"]: + for char in span["chars"]: + same_span = False + if curr_span: + if rotation == 90: + same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][1] - curr_box[3]) / img_size[1] < 0.01 + elif rotation == 180: + same_span = (char["bbox"][2] - curr_box[0]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01 + elif rotation == 270: + same_span = (char["bbox"][0] - curr_box[0]) / img_size[0] < 0.01 and abs(char["bbox"][3] - curr_box[1]) / img_size[1] < 0.01 + else: + same_span = (char["bbox"][0] - curr_box[2]) / img_size[0] < 0.01 and (char["bbox"][1] - curr_box[1]) / img_size[1] < 0.01 + + if curr_span is None: + curr_span = char["char"] + curr_box = char["bbox"] + elif same_span: + curr_span += char["char"] + curr_box = [min(curr_box[0], char["bbox"][0]), min(curr_box[1], char["bbox"][1]), + max(curr_box[2], char["bbox"][2]), max(curr_box[3], char["bbox"][3])] + else: + table_text.append({"text": curr_span, "bbox": curr_box}) + curr_span = char["char"] + curr_box = char["bbox"] + if curr_span is not None: + table_text.append({"text": curr_span, "bbox": curr_box}) + # Adjust to be relative to input table + for item in table_text: + item["bbox"] = [ + item["bbox"][0] - table[0], + item["bbox"][1] - table[1], + item["bbox"][2] - table[0], + item["bbox"][3] - table[1] + ] + table_text = sort_text_lines(table_text) + table_texts.append(table_text) + return table_texts \ No newline at end of file diff --git a/surya/input/processing.py b/surya/input/processing.py new file mode 100644 index 0000000000000000000000000000000000000000..ea8e5fea5fcbb0b4a908d0d0918f3b9fa73246f7 --- /dev/null +++ b/surya/input/processing.py @@ -0,0 +1,118 @@ +from typing import List + +import cv2 +import numpy as np +import math +import pypdfium2 +from PIL import Image, ImageOps, ImageDraw +import torch +from surya.settings import settings + + +def convert_if_not_rgb(images: List[Image.Image]) -> List[Image.Image]: + new_images = [] + for image in images: + if image.mode != "RGB": + image = image.convert("RGB") + new_images.append(image) + return new_images + + +def get_total_splits(image_size, processor): + img_height = list(image_size)[1] + max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT + processor_height = processor.size["height"] + if img_height > max_height: + num_splits = math.ceil(img_height / processor_height) + return num_splits + return 1 + + +def split_image(img, processor): + # This will not modify/return the original image - it will either crop, or copy the image + img_height = list(img.size)[1] + max_height = settings.DETECTOR_IMAGE_CHUNK_HEIGHT + processor_height = processor.size["height"] + if img_height > max_height: + num_splits = math.ceil(img_height / processor_height) + splits = [] + split_heights = [] + for i in range(num_splits): + top = i * processor_height + bottom = (i + 1) * processor_height + if bottom > img_height: + bottom = img_height + cropped = img.crop((0, top, img.size[0], bottom)) + height = bottom - top + if height < 
processor_height: + cropped = ImageOps.pad(cropped, (img.size[0], processor_height), color=255, centering=(0, 0)) + splits.append(cropped) + split_heights.append(height) + return splits, split_heights + return [img.copy()], [img_height] + + +def prepare_image_detection(img, processor): + new_size = (processor.size["width"], processor.size["height"]) + + # This double resize is actually necessary for downstream accuracy + img.thumbnail(new_size, Image.Resampling.LANCZOS) + img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size + + img = np.asarray(img, dtype=np.uint8) + img = processor(img)["pixel_values"][0] + img = torch.from_numpy(img) + return img + + +def open_pdf(pdf_filepath): + return pypdfium2.PdfDocument(pdf_filepath) + + +def get_page_images(doc, indices: List, dpi=settings.IMAGE_DPI): + renderer = doc.render( + pypdfium2.PdfBitmap.to_pil, + page_indices=indices, + scale=dpi / 72, + ) + images = list(renderer) + images = [image.convert("RGB") for image in images] + return images + + +def slice_bboxes_from_image(image: Image.Image, bboxes): + lines = [] + for bbox in bboxes: + line = image.crop((bbox[0], bbox[1], bbox[2], bbox[3])) + if line.size[0] == 0: + print(f"Warning: found an empty line with bbox {bbox}") + lines.append(line) + return lines + + +def slice_polys_from_image(image: Image.Image, polys): + image_array = np.array(image, dtype=np.uint8) + lines = [] + for idx, poly in enumerate(polys): + lines.append(slice_and_pad_poly(image_array, poly)) + return lines + + +def slice_and_pad_poly(image_array: np.array, coordinates): + # Draw polygon onto mask + coordinates = [(corner[0], corner[1]) for corner in coordinates] + bbox = [min([x[0] for x in coordinates]), min([x[1] for x in coordinates]), max([x[0] for x in coordinates]), max([x[1] for x in coordinates])] + + # We mask out anything not in the polygon + cropped_polygon = image_array[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy() + coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates] + + # Pad the area outside the polygon with the pad value + mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8) + cv2.fillPoly(mask, [np.int32(coordinates)], 1) + mask = np.stack([mask] * 3, axis=-1) + + cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE + rectangle_image = Image.fromarray(cropped_polygon) + + return rectangle_image diff --git a/surya/languages.py b/surya/languages.py new file mode 100644 index 0000000000000000000000000000000000000000..d4bfbd48602c4b595602d9d20dd2554a5b767434 --- /dev/null +++ b/surya/languages.py @@ -0,0 +1,102 @@ +CODE_TO_LANGUAGE = { + "_math": "Math", + 'af': 'Afrikaans', + 'am': 'Amharic', + 'ar': 'Arabic', + 'as': 'Assamese', + 'az': 'Azerbaijani', + 'be': 'Belarusian', + 'bg': 'Bulgarian', + 'bn': 'Bengali', + 'br': 'Breton', + 'bs': 'Bosnian', + 'ca': 'Catalan', + 'cs': 'Czech', + 'cy': 'Welsh', + 'da': 'Danish', + 'de': 'German', + 'el': 'Greek', + 'en': 'English', + 'eo': 'Esperanto', + 'es': 'Spanish', + 'et': 'Estonian', + 'eu': 'Basque', + 'fa': 'Persian', + 'fi': 'Finnish', + 'fr': 'French', + 'fy': 'Western Frisian', + 'ga': 'Irish', + 'gd': 'Scottish Gaelic', + 'gl': 'Galician', + 'gu': 'Gujarati', + 'ha': 'Hausa', + 'he': 'Hebrew', + 'hi': 'Hindi', + 'hr': 'Croatian', + 'hu': 'Hungarian', + 'hy': 'Armenian', + 'id': 'Indonesian', + 'is': 'Icelandic', + 'it': 'Italian', + 'ja': 'Japanese', + 'jv': 'Javanese', + 'ka': 'Georgian', + 'kk': 'Kazakh', + 'km': 'Khmer', + 'kn': 'Kannada', + 'ko': 'Korean', + 'ku': 'Kurdish', + 'ky':
'Kyrgyz', + 'la': 'Latin', + 'lo': 'Lao', + 'lt': 'Lithuanian', + 'lv': 'Latvian', + 'mg': 'Malagasy', + 'mk': 'Macedonian', + 'ml': 'Malayalam', + 'mn': 'Mongolian', + 'mr': 'Marathi', + 'ms': 'Malay', + 'my': 'Burmese', + 'ne': 'Nepali', + 'nl': 'Dutch', + 'no': 'Norwegian', + 'om': 'Oromo', + 'or': 'Oriya', + 'pa': 'Punjabi', + 'pl': 'Polish', + 'ps': 'Pashto', + 'pt': 'Portuguese', + 'ro': 'Romanian', + 'ru': 'Russian', + 'sa': 'Sanskrit', + 'sd': 'Sindhi', + 'si': 'Sinhala', + 'sk': 'Slovak', + 'sl': 'Slovenian', + 'so': 'Somali', + 'sq': 'Albanian', + 'sr': 'Serbian', + 'su': 'Sundanese', + 'sv': 'Swedish', + 'sw': 'Swahili', + 'ta': 'Tamil', + 'te': 'Telugu', + 'th': 'Thai', + 'tl': 'Tagalog', + 'tr': 'Turkish', + 'ug': 'Uyghur', + 'uk': 'Ukrainian', + 'ur': 'Urdu', + 'uz': 'Uzbek', + 'vi': 'Vietnamese', + 'xh': 'Xhosa', + 'yi': 'Yiddish', + 'zh': 'Chinese', +} + +LANGUAGE_TO_CODE = {v: k for k, v in CODE_TO_LANGUAGE.items()} + + +def is_arabic(lang_code): + return lang_code in ["ar", "fa", "ps", "ug", "ur"] diff --git a/surya/layout.py b/surya/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7dd4df584fd85768a9c535a8f04e9abbdf98cb --- /dev/null +++ b/surya/layout.py @@ -0,0 +1,229 @@ +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor +from typing import List, Optional +from PIL import Image +import numpy as np + +from surya.detection import batch_detection +from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes +from surya.schema import LayoutResult, LayoutBox, TextDetectionResult +from surya.settings import settings + + +def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]: + logits = np.stack(heatmaps, axis=0) + vertical_line_bboxes = detection_result.vertical_lines + line_bboxes = detection_result.bboxes + + # Scale back to processor size + for line in vertical_line_bboxes: + line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape))) + + for line in line_bboxes: + line.rescale(orig_size, list(reversed(heatmaps[0].shape))) + + for bbox in vertical_line_bboxes: + # Give some width to the vertical lines + vert_bbox = list(bbox.bbox) + vert_bbox[2] = min(heatmaps[0].shape[0], vert_bbox[2] + vertical_line_width) + + logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0 # zero out where the column lines are + + logits[:, logits[0] >= .5] = 0 # zero out where blanks are + + # Zero out where other segments are + for i in range(logits.shape[0]): + logits[i, segment_assignment != i] = 0 + + detected_boxes = [] + for heatmap_idx in range(1, len(id2label)): # Skip the blank class + heatmap = logits[heatmap_idx] + if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD: + continue + bboxes = get_detected_boxes(heatmap) + bboxes = [bbox for bbox in bboxes if bbox.area > 25] + for bb in bboxes: + bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1]) + + for bbox in bboxes: + detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1)) + + detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True) + # Expand bbox to cover intersecting lines + box_lines = defaultdict(list) + used_lines = set() + + # We try 2 rounds of identifying the correct lines to snap to + # First round is majority intersection, second lowers the threshold + for thresh in [.5, .4]: + for 
bbox_idx, bbox in enumerate(detected_boxes): + for line_idx, line_bbox in enumerate(line_bboxes): + if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines: + box_lines[bbox_idx].append(line_bbox.bbox) + used_lines.add(line_idx) + + new_boxes = [] + for bbox_idx, bbox in enumerate(detected_boxes): + if bbox.label == "Picture" and bbox.area < 200: # Remove very small figures + continue + + # Skip if we didn't find any lines to snap to, except for Pictures and Formulas + if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]: + continue + + covered_lines = box_lines[bbox_idx] + # Snap non-picture layout boxes to correct text boundaries + if len(covered_lines) > 0 and bbox.label not in ["Picture"]: + min_x = min([line[0] for line in covered_lines]) + min_y = min([line[1] for line in covered_lines]) + max_x = max([line[2] for line in covered_lines]) + max_y = max([line[3] for line in covered_lines]) + + # Tables and formulas can contain text, but text isn't the whole area + if bbox.label in ["Table", "Formula"]: + min_x_box = min([b[0] for b in bbox.polygon]) + min_y_box = min([b[1] for b in bbox.polygon]) + max_x_box = max([b[0] for b in bbox.polygon]) + max_y_box = max([b[1] for b in bbox.polygon]) + + min_x = min(min_x, min_x_box) + min_y = min(min_y, min_y_box) + max_x = max(max_x, max_x_box) + max_y = max(max_y, max_y_box) + + bbox.polygon = [ + [min_x, min_y], + [max_x, min_y], + [max_x, max_y], + [min_x, max_y] + ] + + if bbox_idx in box_lines and bbox.label in ["Picture"]: + bbox.label = "Figure" + + new_boxes.append(bbox) + + # Merge tables together (sometimes one column is detected as a separate table) + mergeable_types = ["Table", "Picture", "Figure"] + for ftype in mergeable_types: + to_remove = set() + for bbox_idx, bbox in enumerate(new_boxes): + if bbox.label != ftype or bbox_idx in to_remove: + continue + + for bbox_idx2, bbox2 in enumerate(new_boxes): + if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2: + continue + + if bbox.intersection_pct(bbox2, x_margin=.25) > .1: + bbox.merge(bbox2) + to_remove.add(bbox_idx2) + + new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove] + + # Ensure we account for all text lines in the layout + unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines] + for bbox in unused_lines: + new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5)) + + for bbox in new_boxes: + bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size) + + detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16] + + # Remove bboxes contained inside others, unless they're captions + contained_bbox = [] + for i, bbox in enumerate(detected_boxes): + for j, bbox2 in enumerate(detected_boxes): + if i == j: + continue + + if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]: + contained_bbox.append(j) + + detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox] + + return detected_boxes + + +def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]: + bboxes = [] + for i in range(1, len(id2label)): # Skip the blank class + heatmap = heatmaps[i] + assert heatmap.shape == segment_assignment.shape + heatmap[segment_assignment != i] = 0 # zero out where another segment is + + # Skip processing empty labels + if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD: + continue + + bbox = get_and_clean_boxes(heatmap, 
list(reversed(heatmap.shape)), orig_size) + for bb in bbox: + bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i])) + + bboxes = keep_largest_boxes(bboxes) + return bboxes + + +def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult: + logits = np.stack(heatmaps, axis=0) + segment_assignment = logits.argmax(axis=0) + if detection_results is not None: + bboxes = get_regions_from_detection_result(detection_results, heatmaps, orig_size, id2label, + segment_assignment) + else: + bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment) + + segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8)) + + result = LayoutResult( + bboxes=bboxes, + segmentation_map=segmentation_img, + heatmaps=heatmaps, + image_bbox=[0, 0, orig_size[0], orig_size[1]] + ) + + return result + + +def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]: + layout_generator = batch_detection(images, model, processor, batch_size=batch_size) + id2label = model.config.id2label + + results = [] + max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) + parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH + + if parallelize: + with ProcessPoolExecutor(max_workers=max_workers) as executor: + img_idx = 0 + for preds, orig_sizes in layout_generator: + futures = [] + for pred, orig_size in zip(preds, orig_sizes): + future = executor.submit( + parallel_get_regions, + pred, + orig_size, + id2label, + detection_results[img_idx] if detection_results else None + ) + + futures.append(future) + img_idx += 1 + + for future in futures: + results.append(future.result()) + else: + img_idx = 0 + for preds, orig_sizes in layout_generator: + for pred, orig_size in zip(preds, orig_sizes): + results.append(parallel_get_regions( + pred, + orig_size, + id2label, + detection_results[img_idx] if detection_results else None + )) + + img_idx += 1 + + return results \ No newline at end of file diff --git a/surya/model/detection/config.py b/surya/model/detection/config.py new file mode 100644 index 0000000000000000000000000000000000000000..bdbe0a16f8c2e04dc3a74eb10a18894a3225e28b --- /dev/null +++ b/surya/model/detection/config.py @@ -0,0 +1,51 @@ +from transformers import PretrainedConfig + + +class EfficientViTConfig(PretrainedConfig): + r""" + ```""" + + model_type = "efficientvit" + + def __init__( + self, + num_classes=2, + num_channels=3, + widths=(32, 64, 128, 256, 512), + head_dim=32, + num_stages=4, + depths=(1, 1, 1, 6, 6), + strides=(2, 2, 2, 2, 2), + hidden_sizes=(32, 64, 160, 256), + patch_size=(7, 7), + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + classifier_dropout_prob=0.0, + layer_norm_eps=1e-6, + decoder_layer_hidden_size=128, + decoder_hidden_size=512, + semantic_loss_ignore_index=255, + initializer_range=0.02, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_classes = num_classes + self.widths = widths + self.head_dim = head_dim + + self.num_channels = num_channels + self.num_stages = num_stages + self.depths = depths + self.strides = strides + self.hidden_sizes = hidden_sizes + self.patch_size = patch_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.classifier_dropout_prob = classifier_dropout_prob + self.layer_norm_eps = layer_norm_eps + self.decoder_hidden_size = 
decoder_hidden_size + self.decoder_layer_hidden_size = decoder_layer_hidden_size + self.semantic_loss_ignore_index = semantic_loss_ignore_index + + self.initializer_range = initializer_range \ No newline at end of file diff --git a/surya/model/detection/model.py b/surya/model/detection/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7199a294f6e05f1de663436d46a91c1d72822730 --- /dev/null +++ b/surya/model/detection/model.py @@ -0,0 +1,767 @@ +""" +This is an implementation of efficientvit, with some modifications (decode head, etc). + +Original paper at https://arxiv.org/abs/2205.14756 + +Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py +Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit +""" + +from typing import Optional, Union, Tuple +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import PreTrainedModel +from transformers.modeling_outputs import SemanticSegmenterOutput + +from surya.model.detection.config import EfficientViTConfig +from surya.model.detection.processor import SegformerImageProcessor +from surya.settings import settings + + +def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE): + config = EfficientViTConfig.from_pretrained(checkpoint) + model = EfficientViTForSemanticSegmentation.from_pretrained(checkpoint, torch_dtype=dtype, config=config, ignore_mismatched_sizes=True) + model = model.to(device) + model = model.eval() + print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}") + return model + + +def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT): + processor = SegformerImageProcessor.from_pretrained(checkpoint) + return processor + + +def val2list(x: list or tuple or any, repeat_time=1): + if isinstance(x, (list, tuple)): + return list(x) + return [x for _ in range(repeat_time)] + + +def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1): + # repeat elements if necessary + x = val2list(x) + if len(x) > 0: + x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] + + return tuple(x) + + +def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]: + if isinstance(kernel_size, tuple): + return tuple([get_same_padding(ks) for ks in kernel_size]) + else: + assert kernel_size % 2 > 0, "kernel size should be odd number" + return kernel_size // 2 + + +def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int: + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + +class ConvNormAct(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=False, + dropout=0., + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(ConvNormAct, self).__init__() + self.dropout = nn.Dropout(dropout, inplace=False) + padding = get_padding(kernel_size, stride, dilation) + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + bias=bias, + padding=padding, + ) + self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity() + self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity() + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = self.act(x) 
+ return x + + +class DSConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, None), + ): + super(DSConv, self).__init__() + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + + self.depth_conv = ConvNormAct( + in_channels, + in_channels, + kernel_size, + stride, + groups=in_channels, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.point_conv = ConvNormAct( + in_channels, + out_channels, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=1, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, None), + ): + super(ConvBlock, self).__init__() + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.conv1 = ConvNormAct( + in_channels, + mid_channels, + kernel_size, + stride, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.conv2 = ConvNormAct( + mid_channels, + out_channels, + kernel_size, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class MBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, nn.ReLU6, None), + ): + super(MBConv, self).__init__() + use_bias = val2tuple(use_bias, 3) + norm_layer = val2tuple(norm_layer, 3) + act_layer = val2tuple(act_layer, 3) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.inverted_conv = ConvNormAct( + in_channels, + mid_channels, + 1, + stride=1, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.depth_conv = ConvNormAct( + mid_channels, + mid_channels, + kernel_size, + stride=stride, + groups=mid_channels, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + self.point_conv = ConvNormAct( + mid_channels, + out_channels, + 1, + norm_layer=norm_layer[2], + act_layer=act_layer[2], + bias=use_bias[2], + ) + + def forward(self, x): + x = self.inverted_conv(x) + x = self.depth_conv(x) + x = self.point_conv(x) + return x + + +class FusedMBConv(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + stride=1, + mid_channels=None, + expand_ratio=6, + groups=1, + use_bias=False, + norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d), + act_layer=(nn.ReLU6, None), + ): + super(FusedMBConv, self).__init__() + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + mid_channels = mid_channels or round(in_channels * expand_ratio) + + self.spatial_conv = ConvNormAct( + in_channels, + mid_channels, + kernel_size, + stride=stride, + groups=groups, + norm_layer=norm_layer[0], + act_layer=act_layer[0], + bias=use_bias[0], + ) + self.point_conv = ConvNormAct( + 
mid_channels, + out_channels, + 1, + norm_layer=norm_layer[1], + act_layer=act_layer[1], + bias=use_bias[1], + ) + + def forward(self, x): + x = self.spatial_conv(x) + x = self.point_conv(x) + return x + + +class LiteMLA(nn.Module): + """Lightweight multi-scale linear attention""" + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int or None = None, + heads_ratio: float = 1.0, + dim=8, + use_bias=False, + norm_layer=(None, nn.BatchNorm2d), + act_layer=(None, None), + kernel_func=nn.ReLU, + scales=(5,), + eps=1e-5, + ): + super(LiteMLA, self).__init__() + self.eps = eps + heads = heads or int(in_channels // dim * heads_ratio) + total_dim = heads * dim + use_bias = val2tuple(use_bias, 2) + norm_layer = val2tuple(norm_layer, 2) + act_layer = val2tuple(act_layer, 2) + + self.dim = dim + self.qkv = ConvNormAct( + in_channels, + 3 * total_dim, + 1, + bias=use_bias[0], + norm_layer=norm_layer[0], + act_layer=act_layer[0], + ) + self.aggreg = nn.ModuleList([ + nn.Sequential( + nn.Conv2d( + 3 * total_dim, + 3 * total_dim, + scale, + padding=get_same_padding(scale), + groups=3 * total_dim, + bias=use_bias[0], + ), + nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]), + ) + for scale in scales + ]) + self.kernel_func = kernel_func(inplace=False) + + self.proj = ConvNormAct( + total_dim * (1 + len(scales)), + out_channels, + 1, + bias=use_bias[1], + norm_layer=norm_layer[1], + act_layer=act_layer[1], + ) + + def _attn(self, q, k, v): + dtype = v.dtype + q, k, v = q.float(), k.float(), v.float() + kv = k.transpose(-1, -2) @ v + out = q @ kv + out = out[..., :-1] / (out[..., -1:] + self.eps) + return out.to(dtype) + + def forward(self, x): + # Shape is B, C, H, W + B, _, H, W = x.shape + + # generate multi-scale q, k, v + qkv = self.qkv(x) + multi_scale_qkv = [qkv] + for op in self.aggreg: + multi_scale_qkv.append(op(qkv)) + multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1) + multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2) + # Shape for each is B, C, HW, head_dim + q, k, v = multi_scale_qkv.chunk(3, dim=-1) + + # lightweight global attention + q = self.kernel_func(q) + k = self.kernel_func(k) + v = F.pad(v, (0, 1), mode="constant", value=1.) 
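+ # Appending a column of ones to v means the last channel of q @ (k^T v) equals q @ (k^T 1), i.e. the sum of the + # unnormalized per-query attention weights; _attn divides it out (out[..., :-1] / out[..., -1:]) to normalize the softmax-free linear attention.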
+ + out = self._attn(q, k, v) + + # final projection + out = out.transpose(-1, -2).reshape(B, -1, H, W) + out = self.proj(out) + return out + + +class EfficientVitBlock(nn.Module): + def __init__( + self, + in_channels, + heads_ratio=1.0, + head_dim=32, + expand_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + ): + super(EfficientVitBlock, self).__init__() + self.context_module = ResidualBlock( + LiteMLA( + in_channels=in_channels, + out_channels=in_channels, + heads_ratio=heads_ratio, + dim=head_dim, + norm_layer=(None, norm_layer), + ), + nn.Identity(), + ) + self.local_module = ResidualBlock( + MBConv( + in_channels=in_channels, + out_channels=in_channels, + expand_ratio=expand_ratio, + use_bias=(True, True, False), + norm_layer=(None, None, norm_layer), + act_layer=(act_layer, act_layer, None), + ), + nn.Identity(), + ) + + def forward(self, x): + x = self.context_module(x) + x = self.local_module(x) + return x + + +class ResidualBlock(nn.Module): + def __init__( + self, + main: Optional[nn.Module], + shortcut: Optional[nn.Module] = None, + pre_norm: Optional[nn.Module] = None, + ): + super(ResidualBlock, self).__init__() + self.pre_norm = pre_norm if pre_norm is not None else nn.Identity() + self.main = main + self.shortcut = shortcut + + def forward(self, x): + res = self.main(self.pre_norm(x)) + if self.shortcut is not None: + res = res + self.shortcut(x) + return res + + +def build_local_block( + in_channels: int, + out_channels: int, + stride: int, + kernel_size: int, + expand_ratio: float, + norm_layer: str, + act_layer: str, + fewer_norm: bool = False, + block_type: str = "default", +): + assert block_type in ["default", "large", "fused"] + if expand_ratio == 1: + if block_type == "default": + block = DSConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + kernel_size=kernel_size, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + else: + block = ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + kernel_size=kernel_size, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + else: + if block_type == "default": + block = MBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + use_bias=(True, True, False) if fewer_norm else False, + norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, act_layer, None), + ) + else: + block = FusedMBConv( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + use_bias=(True, False) if fewer_norm else False, + norm_layer=(None, norm_layer) if fewer_norm else norm_layer, + act_layer=(act_layer, None), + ) + return block + + +class Stem(nn.Sequential): + def __init__(self, in_chs, out_chs, depth, stride, norm_layer, act_layer, block_type='default'): + super().__init__() + self.stride = stride + + self.add_module( + 'in_conv', + ConvNormAct( + in_chs, out_chs, + kernel_size=stride + 1, stride=stride, norm_layer=norm_layer, act_layer=act_layer, + ) + ) + stem_block = 0 + for _ in range(depth): + self.add_module(f'res{stem_block}', ResidualBlock( + build_local_block( + in_channels=out_chs, + out_channels=out_chs, + stride=1, + kernel_size=3, + expand_ratio=1, + norm_layer=norm_layer, + 
act_layer=act_layer, + block_type=block_type, + ), + nn.Identity(), + )) + stem_block += 1 + + +class EfficientVitLargeStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth, + stride, + norm_layer, + act_layer, + head_dim, + vit_stage=False, + fewer_norm=False, + ): + super(EfficientVitLargeStage, self).__init__() + blocks = [ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=stride, + kernel_size=stride + 1, + expand_ratio=24 if vit_stage else 16, + norm_layer=norm_layer, + act_layer=act_layer, + fewer_norm=vit_stage or fewer_norm, + block_type='default' if fewer_norm else 'fused', + ), + None, + )] + in_chs = out_chs + + if vit_stage: + # for stage 4 + for _ in range(depth): + blocks.append( + EfficientVitBlock( + in_channels=in_chs, + head_dim=head_dim, + expand_ratio=6, + norm_layer=norm_layer, + act_layer=act_layer, + ) + ) + else: + # for stage 1, 2, 3 + for i in range(depth): + blocks.append(ResidualBlock( + build_local_block( + in_channels=in_chs, + out_channels=out_chs, + stride=1, + kernel_size=3, + expand_ratio=4, + norm_layer=norm_layer, + act_layer=act_layer, + fewer_norm=fewer_norm, + block_type='default' if fewer_norm else 'fused', + ), + nn.Identity(), + )) + + self.blocks = nn.Sequential(*blocks) + + def forward(self, x): + return self.blocks(x) + + +class EfficientVitLarge(nn.Module): + def __init__( + self, + config: EfficientViTConfig, + norm_layer=nn.BatchNorm2d, + act_layer=nn.Hardswish, + ): + super(EfficientVitLarge, self).__init__() + self.grad_checkpointing = False + self.num_classes = config.num_classes + self.norm_eps = config.layer_norm_eps + norm_layer = partial(norm_layer, eps=self.norm_eps) + + # input stem + self.stem = Stem(config.num_channels, config.widths[0], config.depths[0], config.strides[0], norm_layer, act_layer, block_type='large') + stride = config.strides[0] + + # stages + self.feature_info = [] + self.stages = nn.Sequential() + in_channels = config.widths[0] + for i, (w, d, s) in enumerate(zip(config.widths[1:], config.depths[1:], config.strides[1:])): + self.stages.append(EfficientVitLargeStage( + in_channels, + w, + depth=d, + stride=s, + norm_layer=norm_layer, + act_layer=act_layer, + head_dim=config.head_dim, + vit_stage=i >= 3, + fewer_norm=i >= 2, + )) + stride *= s + in_channels = w + self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')] + + self.num_features = in_channels + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x): + x = self.stem(x) + encoder_hidden_states = [] + for i, module in enumerate(self.stages): + x = module(x) + encoder_hidden_states.append(x) + + return encoder_hidden_states + + +class EfficientViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = EfficientViTConfig + base_model_prefix = "efficientvit" + main_input_name = "pixel_values" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DecodeMLP(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.proj = nn.Linear(input_dim, output_dim) + + def forward(self, hidden_states: torch.Tensor): + # Input is B, C, H, W + hidden_states = hidden_states.flatten(2).transpose(1, 2) + # Output is B, HW, C + hidden_states = self.proj(hidden_states) + return hidden_states + + +class DecodeHead(EfficientViTPreTrainedModel): + def __init__(self, config: EfficientViTConfig): + super().__init__(config) + + # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size + mlps = [] + for width in config.widths[1:]: + mlp = DecodeMLP(input_dim=width, output_dim=config.decoder_layer_hidden_size) + mlps.append(mlp) + self.linear_c = nn.ModuleList(mlps) + + # the following 3 layers implement the ConvModule of the original implementation + self.linear_fuse = nn.Conv2d( + in_channels=config.decoder_layer_hidden_size * config.num_stages, + out_channels=config.decoder_hidden_size, + kernel_size=1, + bias=False, + ) + self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size) + self.activation = nn.ReLU() + + self.dropout = nn.Dropout(config.classifier_dropout_prob) + self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1) + + self.config = config + + def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor: + batch_size = encoder_hidden_states[-1].shape[0] + + all_hidden_states = () + for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c): + height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3] + encoder_hidden_state = mlp(encoder_hidden_state) # Output is B, HW, C + # Permute to B, C, HW + encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1) + encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width) + # upsample + encoder_hidden_state = nn.functional.interpolate( + encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False + ) + all_hidden_states += (encoder_hidden_state,) + + hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1)) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + # logits are of shape (batch_size, num_labels, height/4, width/4) + logits = self.classifier(hidden_states) + + return logits + + +class EfficientViTForSemanticSegmentation(EfficientViTPreTrainedModel): + def __init__(self, config, **kwargs): + super().__init__(config) + self.vit = EfficientVitLarge(config) + self.decode_head = DecodeHead(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + 
self, + pixel_values: torch.FloatTensor + ) -> Union[Tuple, SemanticSegmenterOutput]: + + # Pixel values should be B,C,H,W + encoder_hidden_states = self.vit( + pixel_values, + ) + + logits = self.decode_head(encoder_hidden_states) + + # Apply sigmoid to get 0-1 output + logits = torch.special.expit(logits) + + return SemanticSegmenterOutput( + loss=None, + logits=logits, + hidden_states=encoder_hidden_states + ) \ No newline at end of file diff --git a/surya/model/detection/processor.py b/surya/model/detection/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..822d7d152b8032bca0c9e1642e8c017ca8067dc0 --- /dev/null +++ b/surya/model/detection/processor.py @@ -0,0 +1,284 @@ +import warnings +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from transformers.image_transforms import to_channel_dimension_format +from transformers.image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + make_list_of_images, +) +from transformers.utils import TensorType + + +import PIL.Image +import torch + + +class SegformerImageProcessor(BaseImageProcessor): + r""" + Constructs a Segformer image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. + size (`Dict[str, int]` *optional*, defaults to `{"height": 512, "width": 512}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use when rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_reduce_labels (`bool`, *optional*, defaults to `False`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is + used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The + background label will be replaced by 255.
Can be overridden by the `do_reduce_labels` parameter in the + `preprocess` method. + """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Dict[str, int] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_reduce_labels: bool = False, + **kwargs, + ) -> None: + if "reduce_labels" in kwargs: + warnings.warn( + "The `reduce_labels` parameter is deprecated and will be removed in a future version. Please use " + "`do_reduce_labels` instead.", + FutureWarning, + ) + do_reduce_labels = kwargs.pop("reduce_labels") + + super().__init__(**kwargs) + size = size if size is not None else {"height": 512, "width": 512} + size = get_size_dict(size) + self.do_resize = do_resize + self.size = size + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD + self.do_reduce_labels = do_reduce_labels + self._valid_processor_keys = [ + "images", + "segmentation_maps", + "do_resize", + "size", + "resample", + "do_rescale", + "rescale_factor", + "do_normalize", + "image_mean", + "image_std", + "do_reduce_labels", + "return_tensors", + "data_format", + "input_data_format", + ] + + @classmethod + def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs): + """ + Overrides the `from_dict` method from the base class to make sure `do_reduce_labels` is updated if image + processor is created using from_dict and kwargs e.g. `SegformerImageProcessor.from_pretrained(checkpoint, + reduce_labels=True)` + """ + image_processor_dict = image_processor_dict.copy() + if "reduce_labels" in kwargs: + image_processor_dict["reduce_labels"] = kwargs.pop("reduce_labels") + return super().from_dict(image_processor_dict, **kwargs) + + def _preprocess( + self, + image: ImageInput, + do_resize: bool, + do_rescale: bool, + do_normalize: bool, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + rescale_factor: Optional[float] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + + return image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + # All transformations expect numpy arrays. 
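+ # Infer channels-first vs. channels-last from the array shape when the caller does not specify it.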
+ if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + image = self._preprocess( + image=image, + do_resize=do_resize, + size=size, + resample=resample, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + input_data_format=input_data_format, + ) + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + return image + + def __call__(self, images, segmentation_maps=None, **kwargs): + """ + Preprocesses a batch of images and optionally segmentation maps. + + Overrides the `__call__` method of the `Preprocessor` class so that both images and segmentation maps can be + passed in as positional arguments. + """ + return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs) + + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput] = None, + do_resize: Optional[bool] = None, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_reduce_labels: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: ChannelDimension = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + segmentation_maps (`ImageInput`, *optional*): + Segmentation map to preprocess. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Size of the image after `resize` is applied. + resample (`int`, *optional*, defaults to `self.resample`): + Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`, Only + has an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation. + do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`): + Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 + is used for background, and background itself is not included in all classes of a dataset (e.g. + ADE20k). The background label will be replaced by 255. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + resample = resample if resample is not None else self.resample + size = size if size is not None else self.size + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + images = [ + self._preprocess_image( + image=img, + do_resize=do_resize, + resample=resample, + size=size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) \ No newline at end of file diff --git a/surya/model/ordering/config.py b/surya/model/ordering/config.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf20f71e7119022d95021e280b4db0e10bf60a5 --- /dev/null +++ b/surya/model/ordering/config.py @@ -0,0 +1,8 @@ +from transformers import MBartConfig, DonutSwinConfig + + +class MBartOrderConfig(MBartConfig): + pass + +class VariableDonutSwinConfig(DonutSwinConfig): + pass \ No newline at end of file diff --git a/surya/model/ordering/decoder.py b/surya/model/ordering/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..89fc3ebce073ba5c4bef5cd3fd049f657324c3b3 --- /dev/null +++ b/surya/model/ordering/decoder.py @@ -0,0 +1,557 @@ +import copy +from typing import Optional, List, Union, Tuple + +from transformers import MBartForCausalLM, MBartConfig +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask +from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, MBartLearnedPositionalEmbedding, MBartDecoderLayer +from surya.model.ordering.config import MBartOrderConfig +import torch +import math + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + From llama + This is the 
equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MBartGQAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[MBartConfig] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0, f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})" + assert embed_dim % self.num_kv_heads == 0, f"embed_dim ({self.embed_dim}) must be divisible by num_kv_heads ({self.num_kv_heads})" + + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, self.num_kv_heads * self.head_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def _shape_key_value(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz) + value_states = 
self._shape_key_value(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + + # Expand kv heads, then match query shape + key_states = repeat_kv(key_states, self.num_kv_groups) + value_states = repeat_kv(value_states, self.num_kv_groups) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+ # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + attn_output = torch.bmm(attn_probs, value_states) + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +MBART_ATTENTION_CLASSES = { + "eager": MBartGQAttention, + "flash_attention_2": None +} + + +class MBartOrderDecoderLayer(MBartDecoderLayer): + def __init__(self, config: MBartConfig): + nn.Module.__init__(self) + self.embed_dim = config.d_model + + self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + num_kv_heads=config.kv_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.embed_dim, + config.decoder_attention_heads, + num_kv_heads=config.kv_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + +class BboxEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.x1_embed = nn.Embedding(config.max_width, config.d_model) + self.y1_embed = nn.Embedding(config.max_height, config.d_model) + self.x2_embed = nn.Embedding(config.max_width, config.d_model) + self.y2_embed = nn.Embedding(config.max_height, config.d_model) + self.w_embed = nn.Embedding(config.max_width, config.d_model) + self.h_embed = nn.Embedding(config.max_height, config.d_model) + self.cx_embed = nn.Embedding(config.max_width, config.d_model) + self.cy_embed = nn.Embedding(config.max_height, config.d_model) + self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.d_model) + + def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor, past_key_values_length: int): + x1, y1, x2, y2 = boxes.unbind(dim=-1) + # Shape is (batch_size, num_boxes/seq len, d_model) + w = x2 - x1 + h = y2 - y1 + # Center x and y in torch long tensors + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + cx = cx.long() + cy = cy.long() + + coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2) + embedded = coord_embeds + self.w_embed(w) + 
self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + + # Add in positional embeddings for the boxes + if past_key_values_length == 0: + for j in range(embedded.shape[0]): + box_start = input_box_counts[j, 0] + box_end = input_box_counts[j, 1] - 1 # Skip the sep token + box_count = box_end - box_start + embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count] + + return embedded + + +class MBartOrderDecoder(MBartDecoder): + def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): + MBartPreTrainedModel.__init__(self, config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BboxEmbedding(config) if embed_tokens is None else embed_tokens + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = MBartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + # Language-specific MoE goes at second and second-to-last layer + self.layers = nn.ModuleList([MBartOrderDecoderLayer(config) for _ in range(config.decoder_layers)]) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_boxes: torch.LongTensor = None, + input_boxes_mask: Optional[torch.Tensor] = None, + input_boxes_counts: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_boxes is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_boxes is not None: + input = input_boxes + input_shape = input_boxes.size()[:-1] # Shape (batch_size, num_boxes) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_boxes, input_boxes_counts, 
past_key_values_length) * self.embed_scale + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = input_boxes_mask if (input_boxes_mask is not None and 0 in input_boxes_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + input_boxes_mask, input_shape, inputs_embeds, past_key_values_length + ) + + if past_key_values_length == 0: + box_ends = input_boxes_counts[:, 1] + box_starts = input_boxes_counts[:, 0] + input_shape_arranged = torch.arange(input_shape[1], device=attention_mask.device)[None, :] + # Enable all boxes to attend to each other (before the sep token) + # Ensure that the boxes are not attending to the padding tokens + boxes_end_mask = input_shape_arranged < box_ends[:, None] + boxes_start_mask = input_shape_arranged >= box_starts[:, None] + boxes_mask = boxes_end_mask & boxes_start_mask + boxes_mask = boxes_mask.unsqueeze(1).unsqueeze(1) # Enable proper broadcasting + attention_mask = attention_mask.masked_fill(boxes_mask, 0) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self._use_flash_attention_2: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + # embed positions + positions = self.embed_positions(input, past_key_values_length) + + hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + hidden_states = self.layernorm_embedding(hidden_states) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {attn_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class MBartOrderDecoderWrapper(MBartPreTrainedModel): + """ + This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is + used in combination with the [`EncoderDecoderModel`] framework. 
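+ Here it wraps `MBartOrderDecoder`, whose inputs are bounding boxes embedded via `BboxEmbedding` rather than token ids.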
+ """ + + def __init__(self, config): + super().__init__(config) + self.decoder = MBartOrderDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class MBartOrder(MBartForCausalLM): + config_class = MBartOrderConfig + _tied_weights_keys = [] + + def __init__(self, config, **kwargs): + config = copy.deepcopy(config) + config.is_decoder = True + config.is_encoder_decoder = False + MBartPreTrainedModel.__init__(self, config) + self.model = MBartOrderDecoderWrapper(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_boxes: torch.LongTensor = None, + input_boxes_mask: Optional[torch.Tensor] = None, + input_boxes_counts: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + input_boxes=input_boxes, + input_boxes_mask=input_boxes_mask, + input_boxes_counts=input_boxes_counts, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) \ No newline at end of file diff --git a/surya/model/ordering/encoder.py b/surya/model/ordering/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ff001b135a558bd3e810a4ffb45ab6de765bfc3a --- /dev/null +++ b/surya/model/ordering/encoder.py @@ -0,0 +1,83 @@ +from torch import nn +import torch +from typing import Optional, Tuple, Union +import collections +import math + +from transformers import DonutSwinPreTrainedModel +from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \ + DonutSwinEncoder + +from surya.model.ordering.config import VariableDonutSwinConfig + +class VariableDonutSwinEmbeddings(DonutSwinEmbeddings): + """ + Construct the patch and position embeddings. 
Optionally, also the mask token. + """ + + def __init__(self, config, use_mask_token=False, **kwargs): + super().__init__(config, use_mask_token) + + self.patch_embeddings = DonutSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + self.position_embeddings = None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + + self.row_embeddings = None + self.column_embeddings = None + if config.use_2d_embeddings: + self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim)) + self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim)) + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs + ) -> Tuple[torch.Tensor]: + + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + # Layernorm across the last dimension (each patch is a single row) + embeddings = self.norm(embeddings) + batch_size, seq_len, embed_dim = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + embeddings = embeddings + self.position_embeddings[:, :seq_len, :] + + if self.row_embeddings is not None and self.column_embeddings is not None: + # Repeat the x position embeddings across the y axis like 0, 1, 2, 3, 0, 1, 2, 3, ... 
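+ # and repeat each y (row) position embedding across the x axis like 0, 0, 0, ..., 1, 1, 1, ..., so every patch position receives a unique row + column embedding sum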
+ row_embeddings = self.row_embeddings[:, :output_dimensions[0], :].repeat_interleave(output_dimensions[1], dim=1) + column_embeddings = self.column_embeddings[:, :output_dimensions[1], :].repeat(1, output_dimensions[0], 1) + + embeddings = embeddings + row_embeddings + column_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +class VariableDonutSwinModel(DonutSwinModel): + config_class = VariableDonutSwinConfig + def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() \ No newline at end of file diff --git a/surya/model/ordering/encoderdecoder.py b/surya/model/ordering/encoderdecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f7351f11f533f01bad9a74cc5ebec7ca272ba8dd --- /dev/null +++ b/surya/model/ordering/encoderdecoder.py @@ -0,0 +1,90 @@ +from typing import Optional, Union, Tuple, List + +import torch +from transformers import VisionEncoderDecoderModel +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput + + +class OrderVisionEncoderDecoderModel(VisionEncoderDecoderModel): + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + decoder_input_boxes: torch.LongTensor = None, + # Shape (batch_size, num_boxes, 4), all coords scaled 0 - 1000, with 1001 as padding + decoder_input_boxes_mask: torch.LongTensor = None, # Shape (batch_size, num_boxes), 0 if padding, 1 otherwise + decoder_input_boxes_counts: torch.LongTensor = None, # Shape (batch_size), number of boxes in each image + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[List[List[int]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # else: 
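+ # No cross-attention mask is built here: the image encoder emits a fixed-length grid of patch features with no padding, so the decoder attends to every encoder position.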
+ encoder_attention_mask = None + + # Decode + decoder_outputs = self.decoder( + input_boxes=decoder_input_boxes, + input_boxes_mask=decoder_input_boxes_mask, + input_boxes_counts=decoder_input_boxes_counts, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + use_cache=use_cache, + past_key_values=past_key_values, + return_dict=return_dict, + labels=labels, + **kwargs_decoder, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqLMOutput( + loss=decoder_outputs.loss, + logits=decoder_outputs.logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) diff --git a/surya/model/ordering/model.py b/surya/model/ordering/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8c92fee9784e330c39433414168a0eb1d697913f --- /dev/null +++ b/surya/model/ordering/model.py @@ -0,0 +1,34 @@ +from transformers import DetrConfig, BeitConfig, DetrImageProcessor, VisionEncoderDecoderConfig, AutoModelForCausalLM, \ + AutoModel +from surya.model.ordering.config import MBartOrderConfig, VariableDonutSwinConfig +from surya.model.ordering.decoder import MBartOrder +from surya.model.ordering.encoder import VariableDonutSwinModel +from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel +from surya.model.ordering.processor import OrderImageProcessor +from surya.settings import settings + + +def load_model(checkpoint=settings.ORDER_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE): + config = VisionEncoderDecoderConfig.from_pretrained(checkpoint) + + decoder_config = vars(config.decoder) + decoder = MBartOrderConfig(**decoder_config) + config.decoder = decoder + + encoder_config = vars(config.encoder) + encoder = VariableDonutSwinConfig(**encoder_config) + config.encoder = encoder + + # Get transformers to load custom model + AutoModel.register(MBartOrderConfig, MBartOrder) + AutoModelForCausalLM.register(MBartOrderConfig, MBartOrder) + AutoModel.register(VariableDonutSwinConfig, VariableDonutSwinModel) + + model = OrderVisionEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype) + assert isinstance(model.decoder, MBartOrder) + assert isinstance(model.encoder, VariableDonutSwinModel) + + model = model.to(device) + model = model.eval() + print(f"Loaded reading order model {checkpoint} on device {device} with dtype {dtype}") + return model \ No newline at end of file diff --git a/surya/model/ordering/processor.py b/surya/model/ordering/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f463be058f554e2aeee7c65e51d1fd9f5bbac6 --- /dev/null +++ b/surya/model/ordering/processor.py @@ -0,0 +1,156 @@ +from copy import deepcopy +from typing import Dict, Union, Optional, List, Tuple + +import torch +from torch import TensorType +from transformers import DonutImageProcessor, DonutProcessor +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \ + 
valid_images, to_numpy_array +import numpy as np +from PIL import Image +import PIL +from surya.settings import settings + + +def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT): + processor = OrderImageProcessor.from_pretrained(checkpoint) + processor.size = settings.ORDER_IMAGE_SIZE + box_size = 1024 + max_tokens = 256 + processor.token_sep_id = max_tokens + box_size + 1 + processor.token_pad_id = max_tokens + box_size + 2 + processor.max_boxes = settings.ORDER_MAX_BOXES - 1 + processor.box_size = {"height": box_size, "width": box_size} + return processor + + +class OrderImageProcessor(DonutImageProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.patch_size = kwargs.get("patch_size", (4, 4)) + + def process_inner(self, images: List[np.ndarray]): + images = [img.transpose(2, 0, 1) for img in images] # convert to CHW format + + assert images[0].shape[0] == 3 # RGB input images, channel dim last + + # Convert to float32 for rescale/normalize + images = [img.astype(np.float32) for img in images] + + # Rescale and normalize + images = [ + self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST) + for img in images + ] + images = [ + self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST) + for img in images + ] + + return images + + def process_boxes(self, boxes): + padded_boxes = [] + box_masks = [] + box_counts = [] + for b in boxes: + # Left pad for generation + padded_b = deepcopy(b) + padded_b.append([self.token_sep_id] * 4) # Sep token to indicate start of label predictions + padded_boxes.append(padded_b) + + max_boxes = max(len(b) for b in padded_boxes) + for i in range(len(padded_boxes)): + pad_len = max_boxes - len(padded_boxes[i]) + box_len = len(padded_boxes[i]) + box_mask = [0] * pad_len + [1] * box_len + padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i] + padded_boxes[i] = padded_box + box_masks.append(box_mask) + box_counts.append([pad_len, max_boxes]) + + return padded_boxes, box_masks, box_counts + + def resize_img_and_boxes(self, img, boxes): + orig_dim = img.size + new_size = (self.size["width"], self.size["height"]) + img.thumbnail(new_size, Image.Resampling.LANCZOS) # Shrink largest dimension to fit new size + img = img.resize(new_size, Image.Resampling.LANCZOS) # Stretch smaller dimension to fit new size + + img = np.asarray(img, dtype=np.uint8) + + width, height = orig_dim + box_width, box_height = self.box_size["width"], self.box_size["height"] + for box in boxes: + # Rescale to 0-1024 + box[0] = box[0] / width * box_width + box[1] = box[1] / height * box_height + box[2] = box[2] / width * box_width + box[3] = box[3] / height * box_height + + if box[0] < 0: + box[0] = 0 + if box[1] < 0: + box[1] = 0 + if box[2] > box_width: + box[2] = box_width + if box[3] > box_height: + box[3] = box_height + + return img, boxes + + def preprocess( + self, + images: ImageInput, + boxes: List[List[int]], + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + random_padding: bool = False, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, 
+ input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + new_images = [] + new_boxes = [] + for img, box in zip(images, boxes): + if len(box) > self.max_boxes: + raise ValueError(f"Too many boxes, max is {self.max_boxes}") + img, box = self.resize_img_and_boxes(img, box) + new_images.append(img) + new_boxes.append(box) + + images = new_images + boxes = new_boxes + + # Convert to numpy for later processing steps + images = [np.array(image) for image in images] + + images = self.process_inner(images) + boxes, box_mask, box_counts = self.process_boxes(boxes) + data = { + "pixel_values": images, + "input_boxes": boxes, + "input_boxes_mask": box_mask, + "input_boxes_counts": box_counts, + } + return BatchFeature(data=data, tensor_type=return_tensors) \ No newline at end of file diff --git a/surya/model/recognition/config.py b/surya/model/recognition/config.py new file mode 100644 index 0000000000000000000000000000000000000000..9ed750b35047a4484d7d07f5b28456ff192eb082 --- /dev/null +++ b/surya/model/recognition/config.py @@ -0,0 +1,348 @@ +from dataclasses import dataclass + +import torch +from transformers import PretrainedConfig +from transformers.utils import ModelOutput + + +class SuryaOCRConfig(PretrainedConfig): + model_type = "vision-encoder-decoder" + is_composition = True + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + encoder_config = kwargs.pop("encoder") + decoder_config = kwargs.pop("decoder") + + self.encoder = encoder_config + self.decoder = decoder_config + self.is_encoder_decoder = True + + if isinstance(decoder_config, dict): + self.decoder_start_token_id = decoder_config["bos_token_id"] + self.pad_token_id = decoder_config["pad_token_id"] + self.eos_token_id = decoder_config["eos_token_id"] + else: + self.decoder_start_token_id = decoder_config.bos_token_id + self.pad_token_id = decoder_config.pad_token_id + self.eos_token_id = decoder_config.eos_token_id + + +class DonutSwinConfig(PretrainedConfig): + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=(256, 896), + patch_size=4, + num_channels=3, + embed_dim=128, + depths=[2, 2, 14, 2], + num_heads=[4, 8, 16, 32], + num_kv_heads=[1, 2, 4, 8], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_length=256, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in 
order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.encoder_length = encoder_length + + +class SuryaOCRDecoderConfig(PretrainedConfig): + model_type = "surya_ocr" + + def __init__( + self, + num_hidden_layers=10, + vocab_size=65792, + hidden_size=1024, + intermediate_size=4 * 1024, + num_attention_heads=16, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=1, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + self_attn_layers=(0, 1, 3, 5, 7, 9), + global_attn_layers=(0, 1, 3, 5, 7, 9), + attention_dropout=0.0, + num_key_value_heads=2, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + tie_word_embeddings=False, + aux_heads=0, # How many n-token-ahead heads to add + encoder_hidden_size=1024, + causal=False, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.aux_heads = aux_heads + self.encoder_hidden_size = encoder_hidden_size + self.causal = causal + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] + + +class SuryaOCRTextEncoderConfig(PretrainedConfig): + model_type = "surya_ocr" + + def __init__( + self, + num_hidden_layers=10, + vocab_size=65792, + hidden_size=1024, + intermediate_size=4 * 1024, + num_attention_heads=16, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=1, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + self_attn_layers=(0, 1, 3, 5, 7, 9), + global_attn_layers=(0, 1, 3, 5, 7, 9), + attention_dropout=0.0, + num_key_value_heads=2, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + 
tie_word_embeddings=False, + aux_heads=0, # How many n-token-ahead heads to add + encoder_hidden_size=1024, + iteration_count=1, + causal=False, + query_token_count=128, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.aux_heads = aux_heads + self.encoder_hidden_size = encoder_hidden_size + self.iteration_count = iteration_count + self.causal = causal + self.query_token_count = query_token_count + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] + +TOTAL_TOKENS = 65536 +TOKEN_OFFSET = 3 # Pad, eos, bos +SPECIAL_TOKENS = 253 +TOTAL_VOCAB_SIZE = TOTAL_TOKENS + TOKEN_OFFSET + SPECIAL_TOKENS +LANGUAGE_MAP = { + 'af': 0, + 'am': 1, + 'ar': 2, + 'as': 3, + 'az': 4, + 'be': 5, + 'bg': 6, + 'bn': 7, + 'br': 8, + 'bs': 9, + 'ca': 10, + 'cs': 11, + 'cy': 12, + 'da': 13, + 'de': 14, + 'el': 15, + 'en': 16, + 'eo': 17, + 'es': 18, + 'et': 19, + 'eu': 20, + 'fa': 21, + 'fi': 22, + 'fr': 23, + 'fy': 24, + 'ga': 25, + 'gd': 26, + 'gl': 27, + 'gu': 28, + 'ha': 29, + 'he': 30, + 'hi': 31, + 'hr': 32, + 'hu': 33, + 'hy': 34, + 'id': 35, + 'is': 36, + 'it': 37, + 'ja': 38, + 'jv': 39, + 'ka': 40, + 'kk': 41, + 'km': 42, + 'kn': 43, + 'ko': 44, + 'ku': 45, + 'ky': 46, + 'la': 47, + 'lo': 48, + 'lt': 49, + 'lv': 50, + 'mg': 51, + 'mk': 52, + 'ml': 53, + 'mn': 54, + 'mr': 55, + 'ms': 56, + 'my': 57, + 'ne': 58, + 'nl': 59, + 'no': 60, + 'om': 61, + 'or': 62, + 'pa': 63, + 'pl': 64, + 'ps': 65, + 'pt': 66, + 'ro': 67, + 'ru': 68, + 'sa': 69, + 'sd': 70, + 'si': 71, + 'sk': 72, + 'sl': 73, + 'so': 74, + 'sq': 75, + 'sr': 76, + 'su': 77, + 'sv': 78, + 'sw': 79, + 'ta': 80, + 'te': 81, + 'th': 82, + 'tl': 83, + 'tr': 84, + 'ug': 85, + 'uk': 86, + 'ur': 87, + 'uz': 88, + 'vi': 89, + 'xh': 90, + 'yi': 91, + 'zh': 92, + "_math": 93 +} \ No newline at end of file diff --git a/surya/model/recognition/decoder.py b/surya/model/recognition/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..071f701448bbfb8ec80697c9bb5b65c5c747479c --- /dev/null +++ b/surya/model/recognition/decoder.py @@ -0,0 +1,695 @@ +from dataclasses import dataclass +from 
typing import Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.utils import ModelOutput + +from surya.model.recognition.config import SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS + +from surya.settings import settings + +_MAX_SQRT_GRADIENT = 1000.0 + + +@dataclass +class OCRModelOutput(ModelOutput): + logits: torch.Tensor + aux_logits: torch.Tensor | None = None + hidden_states: torch.Tensor | None = None + + +class SuryaOCRDecoderRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst SuryaOCRDecoder is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +ALL_LAYERNORM_LAYERS.append(SuryaOCRDecoderRMSNorm) + + +class SuryaOCRDecoderRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, device=None): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaOCRDecoder + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class SuryaOCRDecoderSdpaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper + Modified for GQA + """ + + def __init__(self, config: SuryaOCRDecoderConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = SuryaOCRDecoderRotaryEmbedding( + self.head_dim, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Encoder attention mask currently ignored + + bsz, q_len, _ = hidden_states.size() + _, v_len, _ = encoder_hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + + if self.key_states is None: + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if use_cache: + self._update_cache(key_states, value_states) + else: + key_states = self.key_states + value_states = self.value_states + + key_states = repeat_kv(key_states, self.num_key_value_groups) + 
value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=None, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + def _setup_cache(self, batch_size, device, dtype=None): + # Setup initial caches + self.value_states = None + self.key_states = None + + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + self.value_states = value_states + self.key_states = key_states + + +class SuryaOCRDecoderSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: SuryaOCRDecoderConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = SuryaOCRDecoderRotaryEmbedding( + self.head_dim, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + window_attn: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Final is bsz, num_attention_heads, seq_len, head_dim + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if use_cache and hasattr(self, "key_states"): + cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn} + key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + # Mask is batch, head, seq_len, kv_len + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + current_cache_position = cache_position[-1].item() if cache_position is not None else None + if 
current_cache_position and settings.RECOGNITION_STATIC_CACHE: + # Mask out future cache positions + position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device) + position_mask[:, :, :, :current_cache_position + 1] = False + causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + def _setup_cache(self, batch_size, device, dtype=None): + if dtype is None and self.config.torch_dtype is not None: + dtype = self.config.torch_dtype + dtype = dtype if dtype is not None else torch.float32 + + # Setup initial caches + self.value_states = None + self.key_states = None + + if settings.RECOGNITION_STATIC_CACHE: + cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim) + self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) + self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) + + def _update_static_cache(self, key_states, value_states, **cache_kwargs): + cache_position = cache_kwargs.get("cache_position") + k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device) + + k_out[:, :, cache_position] = key_states.to(k_out.dtype) + v_out[:, :, cache_position] = value_states.to(v_out.dtype) + + self.key_states, self.value_states = k_out, v_out + return k_out, v_out + + def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs): + k_out = key_states + if self.key_states is not None: + k_out = torch.cat([self.key_states, key_states], dim=2) + + v_out = value_states + if self.value_states is not None: + v_out = torch.cat([self.value_states, value_states], dim=2) + + self.key_states, self.value_states = k_out, v_out + return k_out, v_out + + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + if settings.RECOGNITION_STATIC_CACHE: + return self._update_static_cache(key_states, value_states, **cache_kwargs) + + return self._update_dynamic_cache(key_states, value_states, **cache_kwargs) + + +class SuryaOCRDecoderMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_activation is None: + config.hidden_activation = "gelu_pytorch_tanh" + hidden_activation = config.hidden_activation + self.act_fn = ACT2FN[hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class SuryaOCRDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + super().__init__() + self.cross_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.temporal_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + 
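+ # Per-layer wiring: layer_idx in config.self_attn_layers adds a self-attention ("temporal") block, layer_idx in config.cross_attn_layers adds a cross-attention block, and layers not in config.global_attn_layers are flagged as window (local) attention.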
self.temporal_block = None + if layer_idx in config.self_attn_layers: + self.temporal_block = SuryaOCRDecoderSdpaAttention(config) + + self.cross_attn_block = None + if layer_idx in config.cross_attn_layers: + self.cross_attn_block = SuryaOCRDecoderSdpaCrossAttention(config) + + self.window_attn = layer_idx not in config.global_attn_layers + self.channel_pre_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp_block = SuryaOCRDecoderMlp(config) + + def forward( + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + cache_position: torch.Tensor = None, + use_cache: bool = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raw_activations = activations + + if self.cross_attn_block is not None: + # Do cross-attention on encoder outputs + cross_attn_inputs = self.cross_pre_norm(activations) + cross_attn_path = self.cross_attn_block( + cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache + ) + cross_attn_output = cross_attn_path + raw_activations + else: + cross_attn_output = raw_activations + + if self.temporal_block is not None: + inputs_normalized = self.temporal_pre_norm(cross_attn_output) # RMSNorm introduces slight slight differences + hidden_states = self.temporal_block( + inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn + ) + + residual = hidden_states + raw_activations + else: + residual = cross_attn_output + + hidden_states = self.channel_pre_norm(residual) + hidden_states = self.mlp_block(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states + + +class SuryaOCRDecoderPreTrainedModel(PreTrainedModel): + config_class = SuryaOCRDecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SuryaOCRDecoderLayer"] + _skip_keys_device_placement = ["cache"] + _supports_flash_attn_2 = False + _supports_sdpa = False # we can't compare with eager for now + _supports_cache_class = True + _supports_quantized_cache = True + + def _init_weights(self, module): + if isinstance(module, SuryaOCRDecoderSdpaAttention): + torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std) + + torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std) + elif isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + if getattr(module, "bias", None) is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, config, batch, device, dtype): + layers = getattr(self, "model", self).layers + for layer in layers: + if layer.temporal_block: + layer.temporal_block._setup_cache(batch, device, dtype) + if layer.cross_attn_block: + layer.cross_attn_block._setup_cache(batch, device, dtype) + + def reset_cache(self, batch, device, dtype): + pass + + def _tie_weights(self): + pass + + def tie_weights(self): + pass + + +class SuryaOCRDecoderModel(SuryaOCRDecoderPreTrainedModel): + """ + 
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaOCRDecoderDecoderLayer`] + + Args: + config: SuryaOCRDecoderConfig + """ + + def __init__(self, config: SuryaOCRDecoderConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.causal = config.causal + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [SuryaOCRDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_norm = SuryaOCRDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.register_buffer( + "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False + ) + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + prefill: bool = False + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + if use_cache and prefill: + self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + all_hidden_states = () if output_hidden_states else None + for i, residual_block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache + ) + else: + hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache) + + hidden_states = self.final_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` 
is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + # Ignore copy + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if not self.causal: + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = max(settings.RECOGNITION_MAX_TOKENS, sequence_length) + + diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = diagonal + if sequence_length != 1: + # Select the upper triangular part of the matrix, but unmask current token (the diagonal) + # triu will be the min_dtype, everything else is 0 (attended to) + causal_mask = torch.triu(diagonal, diagonal=1) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + # Mask positions in the causal mask that are masked in the attention mask + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if attention_mask is not None and attention_mask.device.type == "cuda": + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class SuryaOCRDecoder(SuryaOCRDecoderPreTrainedModel): + _tied_weights_keys = None + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = SuryaOCRDecoderModel(config) + self.vocab_size = config.vocab_size + aux_heads = config.aux_heads if config.aux_heads is not None else 0 + lm_heads = aux_heads + 1 + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size * lm_heads, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + prefill: bool = False, + **kwargs + ) -> Union[Tuple, OCRModelOutput]: + outputs = self.model( + input_ids=input_ids, + cache_position=cache_position, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + prefill=prefill, + ) + + hidden_states = outputs[0] + all_logits = self.lm_head(hidden_states) + all_logits = torch.split(all_logits, self.vocab_size, dim=-1) + logits = all_logits[0] + aux_logits = all_logits[1:] if len(all_logits) > 1 else None + + return OCRModelOutput( + logits=logits, + aux_logits=aux_logits, + hidden_states=outputs.hidden_states, + ) + +@dataclass +class TextEncoderOutput(CausalLMOutput): + hidden_states: torch.FloatTensor = None + + +class SuryaOCRTextEncoder(SuryaOCRDecoderPreTrainedModel): + _tied_weights_keys = None + config_class = SuryaOCRTextEncoderConfig + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = SuryaOCRDecoderModel(config) + self.vocab_size = config.vocab_size + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutput]: + outputs = self.model( + input_ids=input_ids, + cache_position=cache_position, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + ) + + return TextEncoderOutput( + 
hidden_states=outputs.last_hidden_state, + ) \ No newline at end of file diff --git a/surya/model/recognition/encoder.py b/surya/model/recognition/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..85fb01c80a6b4d25b39a235b3a955c103aaade4f --- /dev/null +++ b/surya/model/recognition/encoder.py @@ -0,0 +1,852 @@ +""" EfficientViT (by MIT Song Han's Lab) + +Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition` + - https://arxiv.org/abs/2205.14756 + +Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py +Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit +""" + +import collections.abc +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer +from transformers.utils import ModelOutput +from surya.model.recognition.config import DonutSwinConfig + +_EXPECTED_OUTPUT_SHAPE = [1, 49, 1024] + + +@dataclass +# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin +class DonutSwinEncoderOutput(ModelOutput): + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class DonutSwinModelOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + + +# Copied from transformers.models.swin.modeling_swin.window_partition +def window_partition(input_feature, window_size): + """ + Partitions the given input into windows. + """ + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, height // window_size, window_size, width // window_size, window_size, num_channels + ) + windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.window_reverse +def window_reverse(windows, window_size, height, width): + """ + Merges windows to produce higher resolution features. + """ + num_channels = windows.shape[-1] + windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels) + windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels) + return windows + + +# Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin +class DonutSwinEmbeddings(nn.Module): + """ + Construct the patch and position embeddings. Optionally, also the mask token. 
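+    The forward pass returns the embedded patch sequence together with the (height, width) of the patch grid produced by the patch embedding step.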
+ """ + + def __init__(self, config, use_mask_token=False): + super().__init__() + + self.patch_embeddings = DonutSwinPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.patch_grid = self.patch_embeddings.grid_size + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None + + if config.use_absolute_embeddings: + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim)) + else: + self.position_embeddings = None + + self.norm = nn.LayerNorm(config.embed_dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor], + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> Tuple[torch.Tensor]: + _, num_channels, height, width = pixel_values.shape + embeddings, output_dimensions = self.patch_embeddings(pixel_values) + embeddings = self.norm(embeddings) + batch_size, seq_len, _ = embeddings.size() + + if bool_masked_pos is not None: + mask_tokens = self.mask_token.expand(batch_size, seq_len, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + if self.position_embeddings is not None: + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings[:, :seq_len] + + embeddings = self.dropout(embeddings) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->DonutSwin +class DonutSwinPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.embed_dim + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + + def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]: + _, num_channels, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions + + +# Copied from transformers.models.swin.modeling_swin.SwinPatchMerging +class DonutSwinPatchMerging(nn.Module): + """ + Patch Merging Layer. + + Args: + input_resolution (`Tuple[int]`): + Resolution of input feature. + dim (`int`): + Number of input channels. + norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`): + Normalization layer class. 
+ """ + + def __init__(self, input_resolution: Tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None: + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature: torch.Tensor, input_dimensions: Tuple[int, int]) -> torch.Tensor: + height, width = input_dimensions + # `dim` is height * width + batch_size, dim, num_channels = input_feature.shape + + input_feature = input_feature.view(batch_size, height, width, num_channels) + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] + # batch_size height/2 width/2 4*num_channels + input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) + input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C + + input_feature = self.norm(input_feature) + input_feature = self.reduction(input_feature) + + return input_feature + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.swin.modeling_swin.SwinDropPath +class DonutSwinDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin +class DonutSwinSelfAttention(nn.Module): + def __init__(self, config, dim, num_heads, num_kv_heads, window_size): + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})" + ) + + self.num_attention_heads = num_heads + self.num_kv_heads = num_kv_heads + self.kv_repeats = self.num_attention_heads // self.num_kv_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.kv_head_size = self.num_kv_heads * self.attention_head_size + self.window_size = ( + window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) + ) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + self.register_buffer("relative_position_index", relative_position_index) + + self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias) + self.value = nn.Linear(self.all_head_size, self.kv_head_size, bias=config.qkv_bias) + + self.dropout_p = config.attention_probs_dropout_prob + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def transpose_kv_for_scores(self, x, repeats): + new_x_shape = x.size()[:-1] + (self.num_kv_heads, self.attention_head_size) + x = x.view(new_x_shape) + x = x.repeat(1, 1, repeats, 1) # repeat the values for each key-value head to match query dim + return x.permute(0, 2, 1, 3).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] 
= None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + batch_size, dim, num_channels = hidden_states.shape + mixed_query_layer = self.query(hidden_states) + + # Final is (batch_size, num_attention_heads, seq_len, attention_head_size) + key_layer = self.transpose_kv_for_scores(self.key(hidden_states), self.kv_repeats) + value_layer = self.transpose_kv_for_scores(self.value(hidden_states), self.kv_repeats) + query_layer = self.transpose_for_scores(mixed_query_layer) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + if attention_mask is None: + attention_mask = relative_position_bias + else: + mask_shape = attention_mask.shape[0] + repeat_count = (batch_size // mask_shape) + attention_mask = attention_mask.repeat(repeat_count, 1, 1).unsqueeze(1) + attention_mask = attention_mask + relative_position_bias + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer.contiguous(), + key_layer.contiguous(), + value_layer.contiguous(), + attn_mask=attention_mask, + dropout_p=self.dropout_p if self.training else 0.0, + scale=self.attention_head_size**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, dim, num_channels) + + outputs = (attn_output,) + return outputs + +# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput +class DonutSwinSelfOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, dim) + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin +class DonutSwinAttention(nn.Module): + def __init__(self, config, dim, num_heads, num_kv_heads, window_size): + super().__init__() + self.self = DonutSwinSelfAttention(config, dim, num_heads, num_kv_heads, window_size) + self.output = DonutSwinSelfOutput(config, dim) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = 
self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinIntermediate +class DonutSwinIntermediate(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(dim, int(config.mlp_ratio * dim)) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinOutput +class DonutSwinOutput(nn.Module): + def __init__(self, config, dim): + super().__init__() + self.dense = nn.Linear(int(config.mlp_ratio * dim), dim) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +# Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin +class DonutSwinLayer(nn.Module): + def __init__(self, config, dim, input_resolution, num_heads, num_kv_heads, shift_size=0): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.shift_size = shift_size + self.window_size = config.window_size + self.input_resolution = input_resolution + self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = DonutSwinAttention(config, dim, num_heads, num_kv_heads, window_size=self.window_size) + self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() + self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.intermediate = DonutSwinIntermediate(config, dim) + self.output = DonutSwinOutput(config, dim) + + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = int(0) + self.window_size = ( + torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution) + ) + + def get_attn_mask(self, height, width, dtype, device): + if self.shift_size > 0: + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device) + height_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + width_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + count = 0 + for height_slice in height_slices: + for width_slice in width_slices: + img_mask[:, height_slice, width_slice, :] = count + count += 1 + + mask_windows = window_partition(img_mask, self.window_size) + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % 
self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not always_partition: + self.set_shift_and_window_size(input_dimensions) + else: + pass + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() + shortcut = hidden_states + + hidden_states = self.layernorm_before(hidden_states) + + hidden_states = hidden_states.view(batch_size, height, width, channels) + + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + + _, height_pad, width_pad, _ = hidden_states.shape + # cyclic shift + if self.shift_size > 0: + shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_hidden_states = hidden_states + + # partition windows + hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) + hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask( + height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device + ) + + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions + ) + + attention_output = attention_outputs[0] + + attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) + + # reverse cyclic shift + if self.shift_size > 0: + attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attention_windows = shifted_windows + + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + + attention_windows = attention_windows.view(batch_size, height * width, channels) + + hidden_states = shortcut + self.drop_path(attention_windows) + + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + layer_output = hidden_states + self.output(layer_output) + + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin +class DonutSwinStage(nn.Module): + def __init__(self, config, dim, input_resolution, depth, num_heads, num_kv_heads, drop_path, downsample): + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList( + [ + DonutSwinLayer( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + shift_size=0 if (i % 2 == 0) else config.window_size // 2, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: 
Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +# Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin +class DonutSwinEncoder(nn.Module): + def __init__(self, config, grid_size): + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] + self.layers = nn.ModuleList( + [ + DonutSwinStage( + config=config, + dim=int(config.embed_dim * 2**i_layer), + input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), + depth=config.depths[i_layer], + num_heads=config.num_heads[i_layer], + num_kv_heads=config.num_kv_heads[i_layer], + drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])], + downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None, + ) + for i_layer in range(self.num_layers) + ] + ) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: Tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_hidden_states_before_downsampling: Optional[bool] = False, + always_partition: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, DonutSwinEncoderOutput]: + all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + always_partition, + ) + else: + layer_outputs = layer_module( + hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition + ) + + hidden_states = layer_outputs[0] + hidden_states_before_downsampling = layer_outputs[1] + output_dimensions = layer_outputs[2] + + 
input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + if output_hidden_states and output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states_before_downsampling.shape + # rearrange b (h w) c -> b c h w + # here we use the original (not downsampled) height and width + reshaped_hidden_state = hidden_states_before_downsampling.view( + batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size + ) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states_before_downsampling,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + elif output_hidden_states and not output_hidden_states_before_downsampling: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[3:] + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return DonutSwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, + ) + + +# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->DonutSwin +class DonutSwinPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DonutSwinConfig + base_model_prefix = "swin" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["DonutSwinStage"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DonutSwinModel(DonutSwinPreTrainedModel): + def __init__(self, config, add_pooling_layer=True, use_mask_token=False): + super().__init__(config) + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1)) + + self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid) + + self.position_embeddings = nn.Parameter(torch.zeros(1, config.encoder_length, config.hidden_size)) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, DonutSwinModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, len(self.config.depths)) + + embedding_output, input_dimensions = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state += self.position_embeddings[:, :last_hidden_state.size(1), :] + + return DonutSwinModelOutput( + last_hidden_state=last_hidden_state, + ) \ No newline at end of file diff --git a/surya/model/recognition/encoderdecoder.py b/surya/model/recognition/encoderdecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..0ec83ef81b009d713cc82fcf898241308b653db6 --- /dev/null +++ b/surya/model/recognition/encoderdecoder.py @@ -0,0 +1,145 @@ +from typing import Optional, Union, Tuple + +import torch +from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput +from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right +from surya.model.recognition.encoder import DonutSwinModel +from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder + + +class OCREncoderDecoderModel(PreTrainedModel): + config_class = VisionEncoderDecoderConfig + base_model_prefix = "vision_encoder_decoder" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + text_encoder: Optional[PreTrainedModel] = None, + ): + # initialize with config + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + 
config.decoder.tie_word_embeddings = False + super().__init__(config) + + if encoder is None: + encoder = DonutSwinModel(config.encoder) + + if decoder is None: + decoder = SuryaOCRDecoder(config.decoder, attn_implementation=config._attn_implementation) + + if text_encoder is None: + text_encoder = SuryaOCRTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation) + + self.encoder = encoder + self.decoder = decoder + self.text_encoder = text_encoder + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + self.text_encoder.config = self.config.text_encoder + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_cache_position: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + + kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + if encoder_outputs is None: + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + encoder_outputs = self.encoder( + pixel_values=pixel_values, + **kwargs_encoder, + ) + elif isinstance(encoder_outputs, tuple): + encoder_outputs = BaseModelOutput(*encoder_outputs) + + encoder_hidden_states = encoder_outputs[0] + + # optionally project encoder_hidden_states + if ( + self.encoder.config.hidden_size != self.decoder.config.hidden_size + and self.decoder.config.cross_attention_hidden_size is None + ): + encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) + + # else: + encoder_attention_mask = None + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + cache_position=decoder_cache_position, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + **kwargs_decoder, + ) + + return Seq2SeqLMOutput( + logits=decoder_outputs.logits, + decoder_hidden_states=decoder_outputs.hidden_states, + encoder_last_hidden_state=encoder_outputs.last_hidden_state + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + input_dict = { + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + 
"encoder_outputs": encoder_outputs, + "past_key_values": decoder_inputs["past_key_values"], + "use_cache": use_cache, + } + return input_dict + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past_key_values, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past_key_values, beam_idx) \ No newline at end of file diff --git a/surya/model/recognition/model.py b/surya/model/recognition/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf9bd75ddda44733b82e38983b3ea0548f270ca --- /dev/null +++ b/surya/model/recognition/model.py @@ -0,0 +1,49 @@ +import warnings + +import torch + +warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated") + +import logging +logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR) + +from typing import List, Optional, Tuple +from surya.model.recognition.encoderdecoder import OCREncoderDecoderModel +from surya.model.recognition.config import DonutSwinConfig, SuryaOCRConfig, SuryaOCRDecoderConfig, SuryaOCRTextEncoderConfig +from surya.model.recognition.encoder import DonutSwinModel +from surya.model.recognition.decoder import SuryaOCRDecoder, SuryaOCRTextEncoder +from surya.settings import settings + +if not settings.ENABLE_EFFICIENT_ATTENTION: + print("Efficient attention is disabled. This will use significantly more VRAM.") + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + + +def load_model(checkpoint=settings.RECOGNITION_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE): + + config = SuryaOCRConfig.from_pretrained(checkpoint) + decoder_config = config.decoder + decoder = SuryaOCRDecoderConfig(**decoder_config) + config.decoder = decoder + + encoder_config = config.encoder + encoder = DonutSwinConfig(**encoder_config) + config.encoder = encoder + + text_encoder_config = config.text_encoder + text_encoder = SuryaOCRTextEncoderConfig(**text_encoder_config) + config.text_encoder = text_encoder + + model = OCREncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype) + + assert isinstance(model.decoder, SuryaOCRDecoder) + assert isinstance(model.encoder, DonutSwinModel) + assert isinstance(model.text_encoder, SuryaOCRTextEncoder) + + model = model.to(device) + model = model.eval() + + print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}") + return model \ No newline at end of file diff --git a/surya/model/recognition/processor.py b/surya/model/recognition/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa20a2f5848a860c066ca27ce87601d884a2bfd --- /dev/null +++ b/surya/model/recognition/processor.py @@ -0,0 +1,206 @@ +from typing import Dict, Union, Optional, List, Iterable + +import cv2 +from torch import TensorType +from transformers import DonutImageProcessor, DonutProcessor +from transformers.image_processing_utils import BatchFeature +from transformers.image_transforms import pad, normalize +from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size +import numpy as np +from PIL import Image 
+import PIL +from surya.model.recognition.tokenizer import Byt5LangTokenizer +from surya.settings import settings + + +def load_processor(): + processor = SuryaProcessor() + processor.image_processor.train = False + processor.image_processor.max_size = settings.RECOGNITION_IMAGE_SIZE + processor.tokenizer.model_max_length = settings.RECOGNITION_MAX_TOKENS + return processor + + +class SuryaImageProcessor(DonutImageProcessor): + def __init__(self, *args, max_size=None, train=False, **kwargs): + super().__init__(*args, **kwargs) + + self.patch_size = kwargs.get("patch_size", (4, 4)) + self.max_size = max_size + self.train = train + + @classmethod + def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4): + max_width, max_height = size["width"], size["height"] + + resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation) + resized_image = resized_image.transpose(2, 0, 1) + + return resized_image + + def process_inner(self, images: List[np.ndarray]): + assert images[0].shape[2] == 3 # RGB input images, channel dim last + + # Rotate if the bbox is wider than it is tall + images = [SuryaImageProcessor.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in images] + + # Verify that the image is wider than it is tall + for img in images: + assert img.shape[1] >= img.shape[0] + + # This also applies the right channel dim format, to channel x height x width + images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images] + assert images[0].shape[0] == 3 # RGB input images, channel dim first + + # Convert to float32 for rescale/normalize + images = [img.astype(np.float32) for img in images] + + # Pads with 255 (whitespace) + # Pad to max size to improve performance + max_size = self.max_size + images = [ + SuryaImageProcessor.pad_image( + image=image, + size=max_size, + input_data_format=ChannelDimension.FIRST, + pad_value=settings.RECOGNITION_PAD_VALUE + ) + for image in images + ] + # Rescale and normalize + for idx in range(len(images)): + images[idx] = images[idx] * self.rescale_factor + images = [ + SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST) + for img in images + ] + + return images + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + random_padding: bool = False, + do_rescale: bool = None, + rescale_factor: float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + images = make_list_of_images(images) + + # Convert to numpy for later processing steps + images = [np.array(img) for img in images] + images = self.process_inner(images) + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + @classmethod + def pad_image( + cls, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_value: float = 0.0, + ) -> np.ndarray: + 
output_height, output_width = size["height"], size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + delta_width = output_width - input_width + delta_height = output_height - input_height + + assert delta_width >= 0 and delta_height >= 0 + + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = ((pad_top, pad_bottom), (pad_left, pad_right)) + return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value) + + @classmethod + def align_long_axis( + cls, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + input_height, input_width = image.shape[:2] + output_height, output_width = size["height"], size["width"] + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = np.rot90(image, 3) + + return image + + @classmethod + def normalize( + cls, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + return normalize( + image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + +class SuryaProcessor(DonutProcessor): + def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs): + image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT) + tokenizer = Byt5LangTokenizer() + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + if tokenizer is None: + raise ValueError("You need to specify a `tokenizer`.") + + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self._in_target_context_manager = False + + def __call__(self, *args, **kwargs): + images = kwargs.pop("images", None) + text = kwargs.pop("text", None) + langs = kwargs.pop("langs", None) + + if len(args) > 0: + images = args[0] + args = args[1:] + + if images is None and text is None: + raise ValueError("You need to specify either an `images` or `text` input to process.") + + if images is not None: + inputs = self.image_processor(images, *args, **kwargs) + + if text is not None: + encodings = self.tokenizer(text, langs, **kwargs) + + if text is None: + return inputs + elif images is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + inputs["langs"] = encodings["langs"] + return inputs \ No newline at end of file diff --git a/surya/model/recognition/tokenizer.py b/surya/model/recognition/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f4201cda6bbf06a2235d1488c44efc0b44fe43b1 --- /dev/null +++ b/surya/model/recognition/tokenizer.py @@ -0,0 +1,120 @@ +from itertools import chain +import random +from typing import List, Optional, Tuple, Union +from tokenizers import AddedToken +from transformers import ByT5Tokenizer +import numpy as np +import torch +from surya.model.recognition.config import LANGUAGE_MAP, TOTAL_TOKENS, TOKEN_OFFSET + + +def text_to_utf16_numbers(text): + utf16_bytes = text.encode('utf-16le') # Little-endian to simplify byte order handling + + numbers = [] + + # 
Iterate through each pair of bytes and combine them into a single number
+    for i in range(0, len(utf16_bytes), 2):
+        # Combine two adjacent bytes into a single number
+        number = utf16_bytes[i] + (utf16_bytes[i + 1] << 8)
+        numbers.append(number)
+
+    return numbers
+
+
+def utf16_numbers_to_text(numbers):
+    byte_array = bytearray()
+    for number in numbers:
+        # Extract the two bytes from the number and add them to the byte array
+        byte_array.append(number & 0xFF)  # Lower byte
+        byte_array.append((number >> 8) & 0xFF)  # Upper byte
+
+    text = byte_array.decode('utf-16le', errors="ignore")
+    return text
+
+
+def _tokenize(text: str, langs: List[str] | None, eos_token_id: int = 1, add_eos: bool = False, add_bos: bool = True):
+    tokens = text_to_utf16_numbers(text)
+    tokens = [t + TOKEN_OFFSET for t in tokens]  # Account for special pad, etc, tokens
+
+    lang_list = []
+    if langs:
+        for lang in langs:
+            code = LANGUAGE_MAP[lang]
+            lang_list.append(code + TOKEN_OFFSET + TOTAL_TOKENS)
+
+    tokens = lang_list + tokens
+
+    if add_bos:
+        tokens.insert(0, eos_token_id)
+
+    return tokens, lang_list
+
+
+class Byt5LangTokenizer(ByT5Tokenizer):
+    def __init__(self,
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        model_max_length=None,
+        **kwargs,
+    ):
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+        self.unk_token = unk_token
+        self.bos_token = eos_token
+        self.offset = TOKEN_OFFSET
+
+        self.pad_id = 0
+        self.eos_id = 1
+        self.unk_id = 2
+
+        self.model_max_length = model_max_length
+        self.special_token_start = TOKEN_OFFSET + TOTAL_TOKENS
+
+        super().__init__()
+
+    def __call__(self, texts: List[str] | str, langs: List[List[str]] | List[str] | None = None, pad_token_id: int = 0, **kwargs):
+        tokenized = []
+        all_langs = []
+
+        is_list = True
+        # Convert to list of lists format
+        if isinstance(texts, str):
+            texts = [texts]
+            is_list = False
+
+        if langs is None:
+            langs = [None] * len(texts)
+
+        if isinstance(langs[0], str):
+            langs = [langs]
+
+        assert len(langs) == len(texts)
+
+        for text, lang in zip(texts, langs):
+            tokens, lang_list = _tokenize(text, lang)
+            tokenized.append(tokens)
+            all_langs.append(lang_list)
+
+        # Convert back to flat format
+        if not is_list:
+            tokenized = tokenized[0]
+            all_langs = all_langs[0]
+
+        return {"input_ids": tokenized, "langs": all_langs}
+
+    def decode(
+        self,
+        token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = None,
+        **kwargs,
+    ) -> str:
+        if isinstance(token_ids, (np.ndarray, torch.Tensor)):
+            token_ids = token_ids.tolist()
+
+        token_ids = [t for t in token_ids if TOKEN_OFFSET <= t < self.special_token_start]
+        token_ids = [t - TOKEN_OFFSET for t in token_ids]
+        text = utf16_numbers_to_text(token_ids)
+        return text
diff --git a/surya/model/table_rec/config.py b/surya/model/table_rec/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8a4ced243ab8be012bb38243ac8d4f36623bec1
--- /dev/null
+++ b/surya/model/table_rec/config.py
@@ -0,0 +1,260 @@
+from transformers import PretrainedConfig
+from surya.settings import settings
+
+BOX_DIM = 1024
+SPECIAL_TOKENS = 7
+MAX_ROWS = 384
+
+
+class SuryaTableRecConfig(PretrainedConfig):
+    model_type = "vision-encoder-decoder"
+    is_composition = True
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        encoder_config = kwargs.pop("encoder")
+        decoder_config = kwargs.pop("decoder")
+        text_enc_config = kwargs.pop("text_encoder")
+
+        self.encoder = encoder_config
+        self.decoder =
decoder_config + self.text_encoder = text_enc_config + self.is_encoder_decoder = True + + if isinstance(decoder_config, dict): + self.decoder_start_token_id = decoder_config["bos_token_id"] + self.pad_token_id = decoder_config["pad_token_id"] + self.eos_token_id = decoder_config["eos_token_id"] + else: + self.decoder_start_token_id = decoder_config.bos_token_id + self.pad_token_id = decoder_config.pad_token_id + self.eos_token_id = decoder_config.eos_token_id + + +class DonutSwinTableRecConfig(PretrainedConfig): + model_type = "donut-swin" + + attribute_map = { + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + } + + def __init__( + self, + image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]), + patch_size=4, + num_channels=3, + embed_dim=128, + depths=[2, 2, 14, 2], + num_heads=[4, 8, 16, 32], + num_kv_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + drop_path_rate=0.1, + hidden_act="gelu", + use_absolute_embeddings=True, + initializer_range=0.02, + layer_norm_eps=1e-5, + encoder_length=1024, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.embed_dim = embed_dim + self.depths = depths + self.num_layers = len(depths) + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.qkv_bias = qkv_bias + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.drop_path_rate = drop_path_rate + self.hidden_act = hidden_act + self.use_absolute_embeddings = use_absolute_embeddings + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel + # this indicates the channel dimension after the last stage of the model + self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) + self.encoder_length = encoder_length + + +class SuryaTableRecDecoderConfig(PretrainedConfig): + model_type = "surya_tablerec" + + def __init__( + self, + num_hidden_layers=3, + vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS, + hidden_size=512, + intermediate_size=4 * 512, + encoder_hidden_size=1024, + num_attention_heads=8, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3), + encoder_cross_attn_layers=(0, 1, 2, 3), + self_attn_layers=(0, 1, 2, 3), + global_attn_layers=(0, 1, 2, 3), + attention_dropout=0.0, + num_key_value_heads=4, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + tie_word_embeddings=False, + aux_heads=0, # How many n-token-ahead heads to add + causal=True, + max_classes=2 + SPECIAL_TOKENS, + max_width=1024 + SPECIAL_TOKENS, + max_height=1024 + SPECIAL_TOKENS, + out_box_size=1024, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + 
self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.aux_heads = aux_heads + self.encoder_hidden_size=encoder_hidden_size + self.causal = causal + self.encoder_cross_attn_layers = encoder_cross_attn_layers + self.max_classes = max_classes + self.max_width = max_width + self.max_height = max_height + self.out_box_size = out_box_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] + + +class SuryaTableRecTextEncoderConfig(PretrainedConfig): + model_type = "surya_tablerec" + + def __init__( + self, + num_hidden_layers=4, + vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS, + hidden_size=1024, + intermediate_size=4 * 1024, + encoder_hidden_size=1024, + num_attention_heads=16, + lru_width=None, + attention_window_size=16, + conv1d_width=4, + logits_soft_cap=30.0, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + hidden_activation="gelu_pytorch_tanh", + rope_theta=10000.0, + block_types=("attention",), + cross_attn_layers=(0, 1, 2, 3, 4, 5), + self_attn_layers=(0, 1, 2, 3, 4, 5), + global_attn_layers=(0, 1, 2, 3, 4, 5), + attention_dropout=0.0, + num_key_value_heads=16, + attention_bias=False, + w_init_variance_scale=0.01, + init_std=0.02, + tie_word_embeddings=False, + causal=False, + max_width=BOX_DIM + SPECIAL_TOKENS, + max_height=BOX_DIM + SPECIAL_TOKENS, + max_position_embeddings=1024, + **kwargs, + ): + self.num_hidden_layers = num_hidden_layers + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_attention_heads = num_attention_heads + self.lru_width = lru_width if lru_width is not None else hidden_size + self.attention_window_size = attention_window_size + self.conv1d_width = conv1d_width + self.logits_soft_cap = logits_soft_cap + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.block_types = list(block_types) + self.hidden_activation = hidden_activation + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads + if self.num_key_value_heads > self.num_attention_heads: + raise ValueError("The number of `num_key_value_heads` must be smaller than `num_attention_heads`") + self.cross_attn_layers = cross_attn_layers + self.self_attn_layers = self_attn_layers + self.global_attn_layers = 
global_attn_layers + self.attention_dropout = attention_dropout + self.attention_bias = attention_bias + self.w_init_variance_scale = w_init_variance_scale + self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers + self.init_std = init_std + self.tie_word_embeddings = tie_word_embeddings + self.encoder_hidden_size = encoder_hidden_size + self.causal = causal + self.max_width = max_width + self.max_height = max_height + self.max_position_embeddings = max_position_embeddings + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + + @property + def layers_block_type(self): + return (self.block_types * 100)[: self.num_hidden_layers] \ No newline at end of file diff --git a/surya/model/table_rec/decoder.py b/surya/model/table_rec/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..262183cb90d7140913c5721148a73f95424621c9 --- /dev/null +++ b/surya/model/table_rec/decoder.py @@ -0,0 +1,795 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.utils import ModelOutput + +from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS + +from surya.settings import settings + +_MAX_SQRT_GRADIENT = 1000.0 + +@dataclass +class TableRecModelOutput(ModelOutput): + bbox_logits: torch.Tensor + class_logits: torch.Tensor | None = None + hidden_states: torch.Tensor | None = None + + +class SuryaTableRecDecoderRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()) + # Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + output = output * (1.0 + self.weight.float()) + return output.type_as(x) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.eps}" + + +ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm) + + +class SuryaTableRecDecoderRotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, device=None): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + 
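+# A minimal, hypothetical shape sketch (not part of the model; the helper name and
+# sizes below are illustrative only). It shows how the rotary embedding above
+# produces per-position cos/sin tables of shape (batch, seq_len, head_dim) and how
+# they are consumed by apply_rotary_pos_emb (defined below), which leaves the
+# query/key shapes unchanged. Relies on the module-level `torch` import.
+def _rotary_embedding_shape_sketch():
+    bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64  # illustrative sizes
+    rope = SuryaTableRecDecoderRotaryEmbedding(head_dim, base=10000)
+    q = torch.randn(bsz, num_heads, seq_len, head_dim)
+    k = torch.randn(bsz, num_heads, seq_len, head_dim)
+    position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1)
+    cos, sin = rope(q, position_ids)  # each: (bsz, seq_len, head_dim)
+    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged
+    assert q_rot.shape == q.shape and k_rot.shape == k.shape
+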
+ +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class SuryaTableRecDecoderSdpaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper + Modified for GQA + """ + + def __init__(self, config: SuryaTableRecDecoderConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding( + self.head_dim, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Encoder attention mask currently ignored + + bsz, q_len, _ = hidden_states.size() + _, v_len, _ = encoder_hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + + if self.key_states is None: + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if use_cache: + self._update_cache(key_states, value_states) + else: + key_states = self.key_states + value_states = self.value_states + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=None, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + def _setup_cache(self, batch_size, device, dtype=None): + # Setup initial caches + self.value_states = None + self.key_states = None + + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + self.value_states = value_states + self.key_states = key_states + + +class 
SuryaTableRecDecoderSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: SuryaTableRecDecoderConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True) + self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding( + self.head_dim, + base=config.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + window_attn: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Final is bsz, num_attention_heads, seq_len, head_dim + query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if use_cache and hasattr(self, "key_states"): + cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn} + key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + # Mask is batch, head, seq_len, kv_len + causal_mask = causal_mask[:, :, :, :key_states.shape[-2]] + current_cache_position = cache_position[-1].item() if cache_position is not None else None + if current_cache_position and settings.RECOGNITION_STATIC_CACHE: + # Mask out future cache positions + position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device) + position_mask[:, :, :, :current_cache_position + 1] = False + causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.contiguous(), + key_states.contiguous(), + value_states.contiguous(), + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + scale=self.head_dim**-0.5, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output + + 
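+    # Cache handling for decoding: `_setup_cache` initialises this layer's key/value
+    # buffers (preallocated to RECOGNITION_MAX_TOKENS when the static cache is enabled
+    # via settings.RECOGNITION_STATIC_CACHE, otherwise left as None), and
+    # `_update_cache` either writes the new key/value states into the static buffers
+    # at `cache_position` or concatenates them onto the dynamic cache along the
+    # sequence dimension.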
def _setup_cache(self, batch_size, device, dtype=None): + if dtype is None and self.config.torch_dtype is not None: + dtype = self.config.torch_dtype + dtype = dtype if dtype is not None else torch.float32 + + # Setup initial caches + self.value_states = None + self.key_states = None + + if settings.RECOGNITION_STATIC_CACHE: + cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim) + self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device) + self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device) + + def _update_static_cache(self, key_states, value_states, **cache_kwargs): + cache_position = cache_kwargs.get("cache_position") + k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device) + + k_out[:, :, cache_position] = key_states.to(k_out.dtype) + v_out[:, :, cache_position] = value_states.to(v_out.dtype) + + self.key_states, self.value_states = k_out, v_out + return k_out, v_out + + def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs): + k_out = key_states + if self.key_states is not None: + k_out = torch.cat([self.key_states, key_states], dim=2) + + v_out = value_states + if self.value_states is not None: + v_out = torch.cat([self.value_states, value_states], dim=2) + + self.key_states, self.value_states = k_out, v_out + return k_out, v_out + + @torch.no_grad() + def _update_cache(self, key_states, value_states, **cache_kwargs): + if settings.RECOGNITION_STATIC_CACHE: + return self._update_static_cache(key_states, value_states, **cache_kwargs) + + return self._update_dynamic_cache(key_states, value_states, **cache_kwargs) + + +class SuryaTableRecDecoderMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_activation is None: + config.hidden_activation = "gelu_pytorch_tanh" + hidden_activation = config.hidden_activation + self.act_fn = ACT2FN[hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class SuryaTableRecDecoderLayer(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + super().__init__() + self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.temporal_block = None + if layer_idx in config.self_attn_layers: + self.temporal_block = SuryaTableRecDecoderSdpaAttention(config) + + self.cross_attn_block = None + if layer_idx in config.cross_attn_layers: + self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config) + + self.window_attn = layer_idx not in config.global_attn_layers + self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp_block = SuryaTableRecDecoderMlp(config) + + def forward( + self, + activations: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_attention_mask: torch.Tensor = None, + cache_position: torch.Tensor = None, + use_cache: bool = None, + ) -> 
Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raw_activations = activations + + if self.cross_attn_block is not None: + # Do cross-attention on encoder outputs + cross_attn_inputs = self.cross_pre_norm(activations) + cross_attn_path = self.cross_attn_block( + cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache + ) + cross_attn_output = cross_attn_path + raw_activations + else: + cross_attn_output = raw_activations + + if self.temporal_block is not None: + inputs_normalized = self.temporal_pre_norm(cross_attn_output) # RMSNorm introduces slight slight differences + hidden_states = self.temporal_block( + inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn + ) + + residual = hidden_states + raw_activations + else: + residual = cross_attn_output + + hidden_states = self.channel_pre_norm(residual) + hidden_states = self.mlp_block(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states + + +class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel): + config_class = SuryaTableRecDecoderConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SuryaTableRecDecoderLayer"] + _skip_keys_device_placement = ["cache"] + _supports_flash_attn_2 = False + _supports_sdpa = False # we can't compare with eager for now + _supports_cache_class = True + _supports_quantized_cache = True + + def _init_weights(self, module): + if isinstance(module, SuryaTableRecDecoderSdpaAttention): + torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std) + torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std) + + torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std) + elif isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std) + if getattr(module, "bias", None) is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.init_std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, config, batch, device, dtype): + layers = getattr(self, "model", self).layers + for layer in layers: + if layer.temporal_block: + layer.temporal_block._setup_cache(batch, device, dtype) + if layer.cross_attn_block: + layer.cross_attn_block._setup_cache(batch, device, dtype) + + def reset_cache(self, batch, device, dtype): + pass + + def _tie_weights(self): + pass + + def tie_weights(self): + pass + + +class LabelEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.vocab_size = config.vocab_size + self.x1_embed = nn.Embedding(config.max_width, config.hidden_size) + self.y1_embed = nn.Embedding(config.max_height, config.hidden_size) + self.x2_embed = nn.Embedding(config.max_width, config.hidden_size) + self.y2_embed = nn.Embedding(config.max_height, config.hidden_size) + self.w_embed = nn.Embedding(config.max_width, config.hidden_size) + self.h_embed = nn.Embedding(config.max_height, config.hidden_size) + self.cx_embed = nn.Embedding(config.max_width, config.hidden_size) + self.cy_embed = nn.Embedding(config.max_height, config.hidden_size) + self.class_embed = nn.Embedding(config.max_classes, config.hidden_size) + self.max_width = config.max_width + self.max_height = 
config.max_height + self.max_classes = config.max_classes + + def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor): + cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1) + # Shape is (batch_size, num_boxes/seq len, d_model) + x1 = (cx - w // 2).long() + y1 = (cy - h // 2).long() + x2 = (cx + w // 2).long() + y2 = (cy + h // 2).long() + x1 = torch.clamp(x1, 0, self.max_width - 1) + y1 = torch.clamp(y1, 0, self.max_height - 1) + x2 = torch.clamp(x2, 0, self.max_width - 1) + y2 = torch.clamp(y2, 0, self.max_height - 1) + + class_ = torch.clamp(class_, 0, self.max_classes - 1).long() + + w = torch.clamp(w, 0, self.max_width - 1).long() + h = torch.clamp(h, 0, self.max_height - 1).long() + cx = torch.clamp(cx, 0, self.max_width - 1).long() + cy = torch.clamp(cy, 0, self.max_height - 1).long() + + coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2) + class_embeds = self.class_embed(class_) + embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + class_embeds + + return embedded + + +class BboxEmbedding(nn.Module): + def __init__(self, config, embed_positions=False): + super().__init__() + self.x1_embed = nn.Embedding(config.max_width, config.hidden_size) + self.y1_embed = nn.Embedding(config.max_height, config.hidden_size) + self.x2_embed = nn.Embedding(config.max_width, config.hidden_size) + self.y2_embed = nn.Embedding(config.max_height, config.hidden_size) + self.w_embed = nn.Embedding(config.max_width, config.hidden_size) + self.h_embed = nn.Embedding(config.max_height, config.hidden_size) + self.cx_embed = nn.Embedding(config.max_width, config.hidden_size) + self.cy_embed = nn.Embedding(config.max_height, config.hidden_size) + self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.max_width = config.max_width + self.max_height = config.max_height + self.embed_positions = embed_positions + + def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor): + x1, y1, x2, y2 = boxes.unbind(dim=-1) + x1 = torch.clamp(x1, 0, self.max_width - 1).long() + y1 = torch.clamp(y1, 0, self.max_height - 1).long() + x2 = torch.clamp(x2, 0, self.max_width - 1).long() + y2 = torch.clamp(y2, 0, self.max_height - 1).long() + + # Shape is (batch_size, num_boxes/seq len, d_model) + w = x2 - x1 + h = y2 - y1 + # Center x and y in torch long tensors + cx = (x1 + x2) / 2 + cy = (y1 + y2) / 2 + cx = cx.long() + cy = cy.long() + + w = torch.clamp(w, 0, self.max_width - 1).long() + h = torch.clamp(h, 0, self.max_height - 1).long() + cx = torch.clamp(cx, 0, self.max_width - 1).long() + cy = torch.clamp(cy, 0, self.max_height - 1).long() + + coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2) + embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + + # Add in positional embeddings for the boxes and labels + if self.embed_positions: + for j in range(embedded.shape[0]): + box_start = input_box_counts[j, 0] + box_end = input_box_counts[j, 1] - 1 # Skip the sep token + box_count = box_end - box_start + embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count] + + return embedded + + +class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`SuryaTableRecDecoderDecoderLayer`] + + Args: + config: SuryaTableRecDecoderConfig + """ + + def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.causal = config.causal + + if embed_labels: + self.embed_tokens = LabelEmbedding(config) + else: + self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions) + + self.layers = nn.ModuleList( + [SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + self.register_buffer( + "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False + ) + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings + def get_input_embeddings(self): + return self.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + input_boxes_counts: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + prefill: bool = False + ) -> Union[Tuple, BaseModelOutputWithNoAttention]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False + + inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts) + hidden_states = inputs_embeds + + if use_cache and prefill: + self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype) + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + + all_hidden_states = () if output_hidden_states else None + for i, residual_block in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache + ) + else: + hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache) + + hidden_states = self.final_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) + + return BaseModelOutputWithNoAttention( + last_hidden_state=hidden_states, + 
hidden_states=all_hidden_states, + ) + + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + # Ignore copy + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if not self.causal: + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length) + + diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + causal_mask = diagonal + if sequence_length != 1: + # Select the upper triangular part of the matrix, but unmask current token (the diagonal) + # triu will be the min_dtype, everything else is 0 (attended to) + causal_mask = torch.triu(diagonal, diagonal=1) + + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + # Mask positions in the causal mask that are masked in the attention mask + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if attention_mask is not None and attention_mask.device.type == "cuda": + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel): + _tied_weights_keys = None + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False) + self.vocab_size = config.vocab_size + + self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False) + self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False) + self.max_width = config.max_width + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + prefill: bool = False, + **kwargs + ) -> Union[Tuple, TableRecModelOutput]: + outputs = self.model( + input_ids=input_ids, + cache_position=cache_position, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + prefill=prefill, + ) + + hidden_states = outputs[0] + bbox_logits = self.bbox_head(hidden_states) + class_logits = self.class_head(hidden_states) + bsz, seq_len = class_logits.shape[:2] + bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width) + + return TableRecModelOutput( + bbox_logits=bbox_logits, + class_logits=class_logits, + hidden_states=hidden_states, + ) +@dataclass +class TextEncoderOutput(CausalLMOutput): + hidden_states: torch.FloatTensor = None + + +class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel): + _tied_weights_keys = None + config_class = SuryaTableRecTextEncoderConfig + + def __init__(self, config, **kwargs): + super().__init__(config) + self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True) + self.vocab_size = config.vocab_size + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + def forward( + self, + input_boxes: Optional[torch.LongTensor] = None, + input_boxes_counts: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutput]: + outputs = self.model( + input_ids=input_boxes, + input_boxes_counts=input_boxes_counts, + 
cache_position=cache_position, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_hidden_states=True, + return_dict=True, + ) + + return TextEncoderOutput( + hidden_states=outputs.last_hidden_state, + ) \ No newline at end of file diff --git a/surya/model/table_rec/encoderdecoder.py b/surya/model/table_rec/encoderdecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9e2b468abe0296a76599aa5ec66eef5e170beb07 --- /dev/null +++ b/surya/model/table_rec/encoderdecoder.py @@ -0,0 +1,135 @@ +import random +from dataclasses import dataclass +from typing import Optional, Union, Tuple + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import PreTrainedModel, VisionEncoderDecoderConfig, PretrainedConfig +from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput +from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import shift_tokens_right +from surya.model.table_rec.decoder import SuryaTableRecTextEncoder, SuryaTableRecDecoder +from surya.model.recognition.encoder import DonutSwinModel +import torch.nn.functional as F +from transformers.utils import ModelOutput + + +@dataclass +class TableRecOutput(ModelOutput): + row_logits: torch.FloatTensor = None + col_logits: torch.FloatTensor = None + decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +class TableRecEncoderDecoderModel(PreTrainedModel): + config_class = VisionEncoderDecoderConfig + base_model_prefix = "vision_encoder_decoder" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_param_buffer_assignment = False + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + encoder: Optional[PreTrainedModel] = None, + text_encoder: Optional[PreTrainedModel] = None, + decoder: Optional[PreTrainedModel] = None, + ): + # initialize with config + # make sure input & output embeddings is not tied + config.tie_word_embeddings = False + config.decoder.tie_word_embeddings = False + super().__init__(config) + + if encoder is None: + encoder = DonutSwinModel(config.encoder) + + if text_encoder is None: + text_encoder = SuryaTableRecTextEncoder(config.text_encoder, attn_implementation=config._attn_implementation) + + if decoder is None: + decoder = SuryaTableRecDecoder(config.decoder, attn_implementation=config._attn_implementation) + + self.encoder = encoder + self.decoder = decoder + self.text_encoder = text_encoder + + # make sure that the individual model's config refers to the shared config + # so that the updates to the config will be synced + self.encoder.config = self.config.encoder + self.decoder.config = self.config.decoder + self.text_encoder.config = self.config.text_encoder + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def get_output_embeddings(self): + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + return self.decoder.set_output_embeddings(new_embeddings) + + def forward( + self, + decoder_input_ids: torch.LongTensor = None, + decoder_cache_position: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple[torch.FloatTensor], TableRecOutput]: + 
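+        """Run the table-rec decoder over precomputed encoder hidden states.
+
+        `encoder_outputs` is passed straight through as `encoder_hidden_states`, so the
+        image encoder (and text encoder) are expected to be run separately beforehand.
+        Any `decoder_`-prefixed kwargs are forwarded to the decoder. Returns a
+        `TableRecOutput` carrying the decoder's row/column logits and hidden states.
+        """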
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} + + kwargs_decoder = { + argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") + } + + # Decode + decoder_outputs = self.decoder( + input_labels=decoder_input_ids, + input_boxes_counts=None, + cache_position=decoder_cache_position, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs, + encoder_attention_mask=None, + use_cache=use_cache, + **kwargs_decoder, + ) + + return TableRecOutput( + row_logits=decoder_outputs.row_logits, + col_logits=decoder_outputs.col_logits, + decoder_hidden_states=decoder_outputs.hidden_states, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + ): + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) + decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None + input_dict = { + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, + "decoder_input_ids": decoder_inputs["input_ids"], + "encoder_outputs": encoder_outputs, + "past_key_values": decoder_inputs["past_key_values"], + "use_cache": use_cache, + } + return input_dict + + def resize_token_embeddings(self, *args, **kwargs): + raise NotImplementedError( + "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" + " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" + ) + + def _reorder_cache(self, past_key_values, beam_idx): + # apply decoder cache reordering here + return self.decoder._reorder_cache(past_key_values, beam_idx) \ No newline at end of file diff --git a/surya/model/table_rec/model.py b/surya/model/table_rec/model.py new file mode 100644 index 0000000000000000000000000000000000000000..7ffb82887bb9b5d9787bff0054ca99db9f0c4137 --- /dev/null +++ b/surya/model/table_rec/model.py @@ -0,0 +1,34 @@ +from surya.model.recognition.encoder import DonutSwinModel +from surya.model.table_rec.config import SuryaTableRecConfig, SuryaTableRecDecoderConfig, DonutSwinTableRecConfig, \ + SuryaTableRecTextEncoderConfig +from surya.model.table_rec.decoder import SuryaTableRecDecoder, SuryaTableRecTextEncoder +from surya.model.table_rec.encoderdecoder import TableRecEncoderDecoderModel +from surya.settings import settings + + +def load_model(checkpoint=settings.TABLE_REC_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE): + + config = SuryaTableRecConfig.from_pretrained(checkpoint) + decoder_config = config.decoder + decoder = SuryaTableRecDecoderConfig(**decoder_config) + config.decoder = decoder + + encoder_config = config.encoder + encoder = DonutSwinTableRecConfig(**encoder_config) + config.encoder = encoder + + text_encoder_config = config.text_encoder + text_encoder = SuryaTableRecTextEncoderConfig(**text_encoder_config) + config.text_encoder = text_encoder + + model = TableRecEncoderDecoderModel.from_pretrained(checkpoint, config=config, torch_dtype=dtype) + + assert isinstance(model.decoder, SuryaTableRecDecoder) + assert isinstance(model.encoder, DonutSwinModel) + assert 
isinstance(model.text_encoder, SuryaTableRecTextEncoder) + + model = model.to(device) + model = model.eval() + + print(f"Loaded recognition model {checkpoint} on device {device} with dtype {dtype}") + return model \ No newline at end of file diff --git a/surya/model/table_rec/processor.py b/surya/model/table_rec/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce56156bf30278ac066cfc7fa7308b96bd82f12 --- /dev/null +++ b/surya/model/table_rec/processor.py @@ -0,0 +1,248 @@ +import math +from typing import Dict, Union, Optional, List, Iterable + +import cv2 +import torch +from torch import TensorType +from transformers import DonutImageProcessor, DonutProcessor +from transformers.image_processing_utils import BatchFeature +from transformers.image_transforms import pad, normalize +from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size +import numpy as np +from PIL import Image +import PIL +from surya.model.recognition.tokenizer import Byt5LangTokenizer +from surya.settings import settings +from surya.model.table_rec.config import BOX_DIM, SPECIAL_TOKENS + + +def load_processor(): + processor = SuryaProcessor() + processor.image_processor.train = False + processor.image_processor.max_size = settings.TABLE_REC_IMAGE_SIZE + + processor.token_pad_id = 0 + processor.token_eos_id = 1 + processor.token_bos_id = 2 + processor.token_row_id = 3 + processor.token_unused_id = 4 + processor.box_size = (BOX_DIM, BOX_DIM) + processor.special_token_count = SPECIAL_TOKENS + return processor + + +class SuryaImageProcessor(DonutImageProcessor): + def __init__(self, *args, max_size=None, train=False, **kwargs): + super().__init__(*args, **kwargs) + + self.patch_size = kwargs.get("patch_size", (4, 4)) + self.max_size = max_size + self.train = train + + @classmethod + def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4): + max_width, max_height = size["width"], size["height"] + + resized_image = cv2.resize(image, (max_width, max_height), interpolation=interpolation) + resized_image = resized_image.transpose(2, 0, 1) + + return resized_image + + def process_inner(self, images: List[np.ndarray]): + assert images[0].shape[2] == 3 # RGB input images, channel dim last + + # This also applies the right channel dim format, to channel x height x width + images = [SuryaImageProcessor.numpy_resize(img, self.max_size, self.resample) for img in images] + assert images[0].shape[0] == 3 # RGB input images, channel dim first + + # Convert to float32 for rescale/normalize + images = [img.astype(np.float32) for img in images] + + # Pads with 255 (whitespace) + # Pad to max size to improve performance + max_size = self.max_size + images = [ + SuryaImageProcessor.pad_image( + image=image, + size=max_size, + input_data_format=ChannelDimension.FIRST, + pad_value=settings.RECOGNITION_PAD_VALUE + ) + for image in images + ] + # Rescale and normalize + for idx in range(len(images)): + images[idx] = images[idx] * self.rescale_factor + images = [ + SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST) + for img in images + ] + + return images + + def preprocess( + self, + images: ImageInput, + do_resize: bool = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_thumbnail: bool = None, + do_align_long_axis: bool = None, + do_pad: bool = None, + random_padding: bool = False, + do_rescale: bool = None, + rescale_factor: 
float = None, + do_normalize: bool = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> PIL.Image.Image: + images = make_list_of_images(images) + + # Convert to numpy for later processing steps + images = [np.array(img) for img in images] + images = self.process_inner(images) + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + @classmethod + def pad_image( + cls, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + pad_value: float = 0.0, + ) -> np.ndarray: + output_height, output_width = size["height"], size["width"] + input_height, input_width = get_image_size(image, channel_dim=input_data_format) + + delta_width = output_width - input_width + delta_height = output_height - input_height + + assert delta_width >= 0 and delta_height >= 0 + + pad_top = delta_height // 2 + pad_left = delta_width // 2 + + pad_bottom = delta_height - pad_top + pad_right = delta_width - pad_left + + padding = ((pad_top, pad_bottom), (pad_left, pad_right)) + return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value) + + @classmethod + def align_long_axis( + cls, + image: np.ndarray, + size: Dict[str, int], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + input_height, input_width = image.shape[:2] + output_height, output_width = size["height"], size["width"] + + if (output_width < output_height and input_width > input_height) or ( + output_width > output_height and input_width < input_height + ): + image = np.rot90(image, 3) + + return image + + @classmethod + def normalize( + cls, + image: np.ndarray, + mean: Union[float, Iterable[float]], + std: Union[float, Iterable[float]], + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + return normalize( + image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + +class SuryaProcessor(DonutProcessor): + def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs): + image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT) + if image_processor is None: + raise ValueError("You need to specify an `image_processor`.") + + tokenizer = Byt5LangTokenizer() + super().__init__(image_processor, tokenizer) + self.current_processor = self.image_processor + self._in_target_context_manager = False + self.max_input_boxes = kwargs.get("max_input_boxes", 256) + self.extra_input_boxes = kwargs.get("extra_input_boxes", 32) + + def resize_boxes(self, img, boxes): + width, height = img.size + box_width, box_height = self.box_size + for box in boxes: + # Rescale to 0-1024 + box[0] = box[0] / width * box_width + box[1] = box[1] / height * box_height + box[2] = box[2] / width * box_width + box[3] = box[3] / height * box_height + + if box[0] < 0: + box[0] = 0 + if box[1] < 0: + box[1] = 0 + if box[2] > box_width: + box[2] = box_width + if box[3] > box_height: + box[3] = box_height 
+ + return boxes + + def __call__(self, *args, **kwargs): + images = kwargs.pop("images", []) + boxes = kwargs.pop("boxes", []) + assert len(images) == len(boxes) + + if len(args) > 0: + images = args[0] + args = args[1:] + + for i in range(len(boxes)): + if len(boxes[i]) > self.max_input_boxes: + downsample_ratio = math.ceil(len(boxes[i]) / self.max_input_boxes) + boxes[i] = boxes[i][::downsample_ratio] + + new_boxes = [] + max_len = self.max_input_boxes + self.extra_input_boxes + box_masks = [] + box_ends = [] + for i in range(len(boxes)): + nb = self.resize_boxes(images[i], boxes[i]) + nb = [[b + self.special_token_count for b in box] for box in nb] # shift up + nb = nb[:self.max_input_boxes - 1] + + nb.insert(0, [self.token_row_id] * 4) # Insert special token for max rows/cols + for _ in range(self.extra_input_boxes): + nb.append([self.token_unused_id] * 4) + + pad_length = max_len - len(nb) + box_mask = [1] * len(nb) + [1] * (pad_length) + box_ends.append(len(nb)) + nb = nb + [[self.token_unused_id] * 4] * pad_length + + new_boxes.append(nb) + box_masks.append(box_mask) + + box_ends = torch.tensor(box_ends, dtype=torch.long) + box_starts = torch.tensor([0] * len(boxes), dtype=torch.long) + box_ranges = torch.stack([box_starts, box_ends], dim=1) + + inputs = self.image_processor(images, *args, **kwargs) + inputs["input_boxes"] = torch.tensor(new_boxes, dtype=torch.long) + inputs["input_boxes_mask"] = torch.tensor(box_masks, dtype=torch.long) + inputs["input_boxes_counts"] = box_ranges + return inputs diff --git a/surya/ocr.py b/surya/ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..d750a3ce128ae5354a90145c2d7ac679a711a322 --- /dev/null +++ b/surya/ocr.py @@ -0,0 +1,114 @@ +from copy import deepcopy +from typing import List +from PIL import Image + +from surya.detection import batch_text_detection +from surya.input.processing import slice_polys_from_image, slice_bboxes_from_image, convert_if_not_rgb +from surya.postprocessing.text import sort_text_lines +from surya.recognition import batch_recognition +from surya.schema import TextLine, OCRResult + + +def run_recognition(images: List[Image.Image], langs: List[List[str] | None], rec_model, rec_processor, bboxes: List[List[List[int]]] = None, polygons: List[List[List[List[int]]]] = None, batch_size=None) -> List[OCRResult]: + # Polygons need to be in corner format - [[x1, y1], [x2, y2], [x3, y3], [x4, y4]], bboxes in [x1, y1, x2, y2] format + assert bboxes is not None or polygons is not None + assert len(images) == len(langs), "You need to pass in one list of languages for each image" + + images = convert_if_not_rgb(images) + + slice_map = [] + all_slices = [] + all_langs = [] + for idx, (image, lang) in enumerate(zip(images, langs)): + if polygons is not None: + slices = slice_polys_from_image(image, polygons[idx]) + else: + slices = slice_bboxes_from_image(image, bboxes[idx]) + slice_map.append(len(slices)) + all_slices.extend(slices) + all_langs.extend([deepcopy(lang)] * len(slices)) + + rec_predictions, _ = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size) + + predictions_by_image = [] + slice_start = 0 + for idx, (image, lang) in enumerate(zip(images, langs)): + slice_end = slice_start + slice_map[idx] + image_lines = rec_predictions[slice_start:slice_end] + slice_start = slice_end + + text_lines = [] + for i in range(len(image_lines)): + if polygons is not None: + poly = polygons[idx][i] + else: + bbox = bboxes[idx][i] + poly = [[bbox[0], bbox[1]], [bbox[2], 
bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]] + + text_lines.append(TextLine( + text=image_lines[i], + polygon=poly + )) + + pred = OCRResult( + text_lines=text_lines, + languages=lang, + image_bbox=[0, 0, image.size[0], image.size[1]] + ) + predictions_by_image.append(pred) + + return predictions_by_image + + +def run_ocr(images: List[Image.Image], langs: List[List[str] | None], det_model, det_processor, rec_model, rec_processor, batch_size=None, highres_images: List[Image.Image] | None = None) -> List[OCRResult]: + images = convert_if_not_rgb(images) + highres_images = convert_if_not_rgb(highres_images) if highres_images is not None else [None] * len(images) + det_predictions = batch_text_detection(images, det_model, det_processor) + + all_slices = [] + slice_map = [] + all_langs = [] + + for idx, (det_pred, image, highres_image, lang) in enumerate(zip(det_predictions, images, highres_images, langs)): + polygons = [p.polygon for p in det_pred.bboxes] + if highres_image: + width_scaler = highres_image.size[0] / image.size[0] + height_scaler = highres_image.size[1] / image.size[1] + scaled_polygons = [[[int(p[0] * width_scaler), int(p[1] * height_scaler)] for p in polygon] for polygon in polygons] + slices = slice_polys_from_image(highres_image, scaled_polygons) + else: + slices = slice_polys_from_image(image, polygons) + slice_map.append(len(slices)) + all_langs.extend([lang] * len(slices)) + all_slices.extend(slices) + + rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size) + + predictions_by_image = [] + slice_start = 0 + for idx, (image, det_pred, lang) in enumerate(zip(images, det_predictions, langs)): + slice_end = slice_start + slice_map[idx] + image_lines = rec_predictions[slice_start:slice_end] + line_confidences = confidence_scores[slice_start:slice_end] + slice_start = slice_end + + assert len(image_lines) == len(det_pred.bboxes) + + lines = [] + for text_line, confidence, bbox in zip(image_lines, line_confidences, det_pred.bboxes): + lines.append(TextLine( + text=text_line, + polygon=bbox.polygon, + bbox=bbox.bbox, + confidence=confidence + )) + + lines = sort_text_lines(lines) + + predictions_by_image.append(OCRResult( + text_lines=lines, + languages=lang, + image_bbox=det_pred.image_bbox + )) + + return predictions_by_image diff --git a/surya/ordering.py b/surya/ordering.py new file mode 100644 index 0000000000000000000000000000000000000000..820bd742975aeaab83dfdb7e2ab2ef832058bb50 --- /dev/null +++ b/surya/ordering.py @@ -0,0 +1,141 @@ +from copy import deepcopy +from typing import List +import torch +from PIL import Image + +from surya.input.processing import convert_if_not_rgb +from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel +from surya.schema import OrderBox, OrderResult +from surya.settings import settings +from tqdm import tqdm +import numpy as np + + +def get_batch_size(): + batch_size = settings.ORDER_BATCH_SIZE + if batch_size is None: + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 32 + return batch_size + + +def rank_elements(arr): + enumerated_and_sorted = sorted(enumerate(arr), key=lambda x: x[1]) + rank = [0] * len(arr) + + for rank_value, (original_index, value) in enumerate(enumerated_and_sorted): + rank[original_index] = rank_value + + return rank + + +def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, 
batch_size=None) -> List[OrderResult]: + assert all([isinstance(image, Image.Image) for image in images]) + assert len(images) == len(bboxes) + if batch_size is None: + batch_size = get_batch_size() + + + output_order = [] + for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"): + batch_bboxes = deepcopy(bboxes[i:i+batch_size]) + batch_images = images[i:i+batch_size] + batch_images = [image.convert("RGB") for image in batch_images] # also copies the images + + orig_sizes = [image.size for image in batch_images] + model_inputs = processor(images=batch_images, boxes=batch_bboxes) + + batch_pixel_values = model_inputs["pixel_values"] + batch_bboxes = model_inputs["input_boxes"] + batch_bbox_mask = model_inputs["input_boxes_mask"] + batch_bbox_counts = model_inputs["input_boxes_counts"] + + batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device) + batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device) + batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device) + batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device) + + token_count = 0 + past_key_values = None + encoder_outputs = None + batch_predictions = [[] for _ in range(len(batch_images))] + done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device) + + with torch.inference_mode(): + while token_count < settings.ORDER_MAX_BOXES: + return_dict = model( + pixel_values=batch_pixel_values, + decoder_input_boxes=batch_bboxes, + decoder_input_boxes_mask=batch_bbox_mask, + decoder_input_boxes_counts=batch_bbox_counts, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + ) + logits = return_dict["logits"].detach() + + last_tokens = [] + last_token_mask = [] + min_val = torch.finfo(model.dtype).min + for j in range(logits.shape[0]): + label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1 # Subtract 1 for the sep token + new_logits = logits[j, -1] + new_logits[batch_predictions[j]] = min_val # Mask out already predicted tokens, we can only predict each token once + new_logits[label_count:] = min_val # Mask out all logit positions above the number of bboxes + pred = int(torch.argmax(new_logits, dim=-1).item()) + + # Add one to avoid colliding with the 1000 height/width token for bboxes + last_tokens.append([[pred + processor.box_size["height"] + 1] * 4]) + if len(batch_predictions[j]) == label_count - 1: # Minus one since we're appending the final label + last_token_mask.append([0]) + batch_predictions[j].append(pred) + done[j] = True + elif len(batch_predictions[j]) < label_count - 1: + last_token_mask.append([1]) + batch_predictions[j].append(pred) # Get rank prediction for given position + else: + last_token_mask.append([0]) + + if done.all(): + break + + past_key_values = return_dict["past_key_values"] + encoder_outputs = (return_dict["encoder_last_hidden_state"],) + + batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device) + token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device) + batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1) + token_count += 1 + + for j, row_pred in enumerate(batch_predictions): + row_bboxes = bboxes[i+j] + assert len(row_pred) == len(row_bboxes), f"Mismatch between logits and bboxes. 
Logits: {len(row_pred)}, Bboxes: {len(row_bboxes)}" + + orig_size = orig_sizes[j] + ranks = [0] * len(row_bboxes) + + for box_idx in range(len(row_bboxes)): + ranks[row_pred[box_idx]] = box_idx + + order_boxes = [] + for row_bbox, rank in zip(row_bboxes, ranks): + order_box = OrderBox( + bbox=row_bbox, + position=rank, + ) + order_boxes.append(order_box) + + result = OrderResult( + bboxes=order_boxes, + image_bbox=[0, 0, orig_size[0], orig_size[1]], + ) + output_order.append(result) + return output_order + + + + + + diff --git a/surya/postprocessing/affinity.py b/surya/postprocessing/affinity.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb538cbe6fed233e03c03e4bd198a7e896d6165 --- /dev/null +++ b/surya/postprocessing/affinity.py @@ -0,0 +1,165 @@ +from typing import List + +import cv2 +import numpy as np + +from PIL import Image, ImageDraw + +from surya.postprocessing.util import get_line_angle, rescale_bbox +from surya.schema import ColumnLine + + +def get_detected_lines_sobel(image, vertical=True): + # Apply Sobel operator with a kernel size of 3 to detect vertical edges + if vertical: + dx = 1 + dy = 0 + else: + dx = 0 + dy = 1 + + sobelx = cv2.Sobel(image, cv2.CV_32F, dx, dy, ksize=3) + + + # Absolute Sobel (to capture both edges) + abs_sobelx = np.absolute(sobelx) + + # Convert to 8-bit image + scaled_sobel = np.uint8(255 * abs_sobelx / np.max(abs_sobelx)) + + kernel = np.ones((20, 1), np.uint8) + eroded = cv2.erode(scaled_sobel, kernel, iterations=1) + scaled_sobel = cv2.dilate(eroded, kernel, iterations=3) + + return scaled_sobel + + +def get_detected_lines(image, slope_tol_deg=2, vertical=False, horizontal=False) -> List[ColumnLine]: + assert not (vertical and horizontal) + new_image = image.astype(np.float32) * 255 # Convert to 0-255 range + if vertical or horizontal: + new_image = get_detected_lines_sobel(new_image, vertical) + new_image = new_image.astype(np.uint8) + + edges = cv2.Canny(new_image, 150, 200, apertureSize=3) + if vertical: + max_gap = 100 + min_length = 10 + else: + max_gap = 10 + min_length = 4 + + lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=150, minLineLength=min_length, maxLineGap=max_gap) + + line_info = [] + if lines is not None: + for line in lines: + vertical_line = False + horizontal_line = False + x1, y1, x2, y2 = line[0] + bbox = [x1, y1, x2, y2] + + if x2 == x1: + vertical_line = True + else: + line_angle = get_line_angle(x1, y1, x2, y2) + if 90 - slope_tol_deg < line_angle < 90 + slope_tol_deg: + vertical_line = True + elif -90 - slope_tol_deg < line_angle < -90 + slope_tol_deg: + vertical_line = True + elif -slope_tol_deg < line_angle < slope_tol_deg: + horizontal_line = True + + if bbox[3] < bbox[1]: + bbox[1], bbox[3] = bbox[3], bbox[1] + if bbox[2] < bbox[0]: + bbox[0], bbox[2] = bbox[2], bbox[0] + row = ColumnLine(bbox=bbox, vertical=vertical_line, horizontal=horizontal_line) + line_info.append(row) + + if vertical: + line_info = [line for line in line_info if line.vertical] + + if horizontal: + line_info = [line for line in line_info if line.horizontal] + + return line_info + + +def draw_lines_on_image(line_info: List[ColumnLine], img): + draw = ImageDraw.Draw(img) + + for line in line_info: + divisor = 20 + if line.horizontal: + divisor = 200 + x1, y1, x2, y2 = [x // divisor * divisor for x in line.bbox] + if line.vertical: + draw.line((x1, y1, x2, y2), fill="red", width=3) + + return img + + +def get_vertical_lines(image, processor_size, image_size, divisor=20, x_tolerance=40, y_tolerance=20) -> 
List[ColumnLine]: + vertical_lines = get_detected_lines(image, vertical=True) + for line in vertical_lines: + line.rescale_bbox(processor_size, image_size) + vertical_lines = sorted(vertical_lines, key=lambda x: x.bbox[0]) + for line in vertical_lines: + line.round_bbox(divisor) + + # Merge adjacent line segments together + to_remove = [] + for i, line in enumerate(vertical_lines): + for j, line2 in enumerate(vertical_lines): + if j <= i: + continue + if line.bbox[0] != line2.bbox[0]: + continue + + expanded_line1 = [line.bbox[0], line.bbox[1] - y_tolerance, line.bbox[2], + line.bbox[3] + y_tolerance] + + line1_points = set(range(int(expanded_line1[1]), int(expanded_line1[3]))) + line2_points = set(range(int(line2.bbox[1]), int(line2.bbox[3]))) + intersect_y = len(line1_points.intersection(line2_points)) > 0 + + if intersect_y: + vertical_lines[j].bbox[1] = min(line.bbox[1], line2.bbox[1]) + vertical_lines[j].bbox[3] = max(line.bbox[3], line2.bbox[3]) + to_remove.append(i) + + vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove] + + # Remove redundant segments + to_remove = [] + for i, line in enumerate(vertical_lines): + if i in to_remove: + continue + for j, line2 in enumerate(vertical_lines): + if j <= i or j in to_remove: + continue + close_in_x = abs(line.bbox[0] - line2.bbox[0]) < x_tolerance + line1_points = set(range(int(line.bbox[1]), int(line.bbox[3]))) + line2_points = set(range(int(line2.bbox[1]), int(line2.bbox[3]))) + + intersect_y = len(line1_points.intersection(line2_points)) > 0 + + if close_in_x and intersect_y: + # Keep the longer line and extend it + if len(line2_points) > len(line1_points): + vertical_lines[j].bbox[1] = min(line.bbox[1], line2.bbox[1]) + vertical_lines[j].bbox[3] = max(line.bbox[3], line2.bbox[3]) + to_remove.append(i) + else: + vertical_lines[i].bbox[1] = min(line.bbox[1], line2.bbox[1]) + vertical_lines[i].bbox[3] = max(line.bbox[3], line2.bbox[3]) + to_remove.append(j) + + vertical_lines = [line for i, line in enumerate(vertical_lines) if i not in to_remove] + + if len(vertical_lines) > 0: + # Always start with top left of page + vertical_lines[0].bbox[1] = 0 + + return vertical_lines \ No newline at end of file diff --git a/surya/postprocessing/fonts.py b/surya/postprocessing/fonts.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e18789c356413ac544da345a158f253fa7365b --- /dev/null +++ b/surya/postprocessing/fonts.py @@ -0,0 +1,24 @@ +from typing import List, Optional +import os +import requests + +from surya.settings import settings + + +def get_font_path(langs: Optional[List[str]] = None) -> str: + font_path = settings.RECOGNITION_RENDER_FONTS["all"] + if langs is not None: + for k in settings.RECOGNITION_RENDER_FONTS: + if k in langs and len(langs) == 1: + font_path = settings.RECOGNITION_RENDER_FONTS[k] + break + + if not os.path.exists(font_path): + os.makedirs(os.path.dirname(font_path), exist_ok=True) + font_dl_path = f"{settings.RECOGNITION_FONT_DL_BASE}/{os.path.basename(font_path)}" + with requests.get(font_dl_path, stream=True) as r, open(font_path, 'wb') as f: + r.raise_for_status() + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + + return font_path \ No newline at end of file diff --git a/surya/postprocessing/heatmap.py b/surya/postprocessing/heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5e7487fb906fc4ff25bdebf67249364bb22c10 --- /dev/null +++ b/surya/postprocessing/heatmap.py @@ -0,0 +1,224 @@ +from typing import List, 
Tuple + +import numpy as np +import cv2 +import math +from PIL import ImageDraw, ImageFont + +from surya.postprocessing.fonts import get_font_path +from surya.postprocessing.util import rescale_bbox +from surya.schema import PolygonBox +from surya.settings import settings +from surya.postprocessing.text import get_text_size + + +def keep_largest_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: + new_boxes = [] + for box_obj in boxes: + box = box_obj.bbox + box_area = (box[2] - box[0]) * (box[3] - box[1]) + contained = False + for other_box_obj in boxes: + if other_box_obj.polygon == box_obj.polygon: + continue + + other_box = other_box_obj.bbox + other_box_area = (other_box[2] - other_box[0]) * (other_box[3] - other_box[1]) + if box == other_box: + continue + # find overlap percentage + overlap = box_obj.intersection_pct(other_box_obj) + if overlap > .9 and box_area < other_box_area: + contained = True + break + if not contained: + new_boxes.append(box_obj) + return new_boxes + + +def clean_contained_boxes(boxes: List[PolygonBox]) -> List[PolygonBox]: + new_boxes = [] + for box_obj in boxes: + box = box_obj.bbox + contained = False + for other_box_obj in boxes: + if other_box_obj.polygon == box_obj.polygon: + continue + + other_box = other_box_obj.bbox + if box == other_box: + continue + if box[0] >= other_box[0] and box[1] >= other_box[1] and box[2] <= other_box[2] and box[3] <= other_box[3]: + contained = True + break + if not contained: + new_boxes.append(box_obj) + return new_boxes + + +def get_dynamic_thresholds(linemap, text_threshold, low_text, typical_top10_avg=0.7): + # Find average intensity of top 10% pixels + flat_map = linemap.ravel() + top_10_count = int(len(flat_map) * 0.9) + avg_intensity = np.mean(np.partition(flat_map, top_10_count)[top_10_count:]) + scaling_factor = np.clip(avg_intensity / typical_top10_avg, 0, 1) ** (1 / 2) + + low_text = np.clip(low_text * scaling_factor, 0.1, 0.6) + text_threshold = np.clip(text_threshold * scaling_factor, 0.15, 0.8) + + return text_threshold, low_text + + +def detect_boxes(linemap, text_threshold, low_text): + # From CRAFT - https://github.com/clovaai/CRAFT-pytorch + # Modified to return boxes and for speed, accuracy + img_h, img_w = linemap.shape + + text_threshold, low_text = get_dynamic_thresholds(linemap, text_threshold, low_text) + + text_score_comb = (linemap > low_text).astype(np.uint8) + label_count, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb, connectivity=4) + + det = [] + confidences = [] + max_confidence = 0 + + for k in range(1, label_count): + # size filtering + size = stats[k, cv2.CC_STAT_AREA] + if size < 10: + continue + + # make segmentation map + x, y, w, h = stats[k, [cv2.CC_STAT_LEFT, cv2.CC_STAT_TOP, cv2.CC_STAT_WIDTH, cv2.CC_STAT_HEIGHT]] + + try: + niter = int(np.sqrt(min(w, h))) + except ValueError: + niter = 0 + + buffer = 1 + sx, sy = max(0, x - niter - buffer), max(0, y - niter - buffer) + ex, ey = min(img_w, x + w + niter + buffer), min(img_h, y + h + niter + buffer) + + mask = (labels[sy:ey, sx:ex] == k) + selected_linemap = linemap[sy:ey, sx:ex][mask] + line_max = np.max(selected_linemap) + + # thresholding + if line_max < text_threshold: + continue + + segmap = mask.astype(np.uint8) + + ksize = buffer + niter + kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(ksize, ksize)) + selected_segmap = cv2.dilate(segmap, kernel) + + # make box + indices = np.nonzero(selected_segmap) + x_inds = indices[1] + sx + y_inds = indices[0] + sy + np_contours = 
np.column_stack((x_inds, y_inds)) + rectangle = cv2.minAreaRect(np_contours) + box = cv2.boxPoints(rectangle) + + # align diamond-shape + w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2]) + box_ratio = max(w, h) / (min(w, h) + 1e-5) + if abs(1 - box_ratio) <= 0.1: + l, r = min(np_contours[:, 0]), max(np_contours[:, 0]) + t, b = min(np_contours[:, 1]), max(np_contours[:, 1]) + box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32) + + # make clock-wise order + startidx = box.sum(axis=1).argmin() + box = np.roll(box, 4-startidx, 0) + box = np.array(box) + + confidence = line_max + max_confidence = max(max_confidence, line_max) + + confidences.append(confidence) + det.append(box) + + if max_confidence > 0: + confidences = [c / max_confidence for c in confidences] + return det, confidences + + +def get_detected_boxes(textmap, text_threshold=None, low_text=None) -> List[PolygonBox]: + if text_threshold is None: + text_threshold = settings.DETECTOR_TEXT_THRESHOLD + + if low_text is None: + low_text = settings.DETECTOR_BLANK_THRESHOLD + + textmap = textmap.copy() + textmap = textmap.astype(np.float32) + boxes, confidences = detect_boxes(textmap, text_threshold, low_text) + # From point form to box form + boxes = [PolygonBox(polygon=box, confidence=confidence) for box, confidence in zip(boxes, confidences)] + return boxes + + +def get_and_clean_boxes(textmap, processor_size, image_size, text_threshold=None, low_text=None) -> List[PolygonBox]: + bboxes = get_detected_boxes(textmap, text_threshold, low_text) + for bbox in bboxes: + bbox.rescale(processor_size, image_size) + bbox.fit_to_bounds([0, 0, image_size[0], image_size[1]]) + + bboxes = clean_contained_boxes(bboxes) + return bboxes + + + +def draw_bboxes_on_image(bboxes, image, labels=None, label_font_size=10, color: str | list='red'): + polys = [] + for bb in bboxes: + # Clockwise polygon + poly = [ + [bb[0], bb[1]], + [bb[2], bb[1]], + [bb[2], bb[3]], + [bb[0], bb[3]] + ] + polys.append(poly) + + return draw_polys_on_image(polys, image, labels, label_font_size=label_font_size, color=color) + + +def draw_polys_on_image(corners, image, labels=None, box_padding=-1, label_offset=1, label_font_size=10, color: str | list='red'): + draw = ImageDraw.Draw(image) + font_path = get_font_path() + label_font = ImageFont.truetype(font_path, label_font_size) + + for i in range(len(corners)): + poly = corners[i] + poly = [(int(p[0]), int(p[1])) for p in poly] + draw.polygon(poly, outline=color[i] if isinstance(color, list) else color, width=1) + + if labels is not None: + label = labels[i] + text_position = ( + min([p[0] for p in poly]) + label_offset, + min([p[1] for p in poly]) + label_offset + ) + text_size = get_text_size(label, label_font) + box_position = ( + text_position[0] - box_padding + label_offset, + text_position[1] - box_padding + label_offset, + text_position[0] + text_size[0] + box_padding + label_offset, + text_position[1] + text_size[1] + box_padding + label_offset + ) + draw.rectangle(box_position, fill="white") + draw.text( + text_position, + label, + fill=color[i] if isinstance(color, list) else color, + font=label_font + ) + + return image + + diff --git a/surya/postprocessing/math/latex.py b/surya/postprocessing/math/latex.py new file mode 100644 index 0000000000000000000000000000000000000000..b07e5fb8e51200dbb32e0055e8ea1de04b008caf --- /dev/null +++ b/surya/postprocessing/math/latex.py @@ -0,0 +1,125 @@ +import re +from ftfy import fix_text + + +def contains_math(text): + return 
text.startswith("$") or text.endswith("$") + + +def fix_math(text): + # Fix any issues with the text + text = fix_text(text) + + # Remove LaTeX labels and references + text = remove_labels(text) + text = replace_katex_invalid(text) + text = fix_fences(text) + return text + + +def remove_labels(text): + pattern = r'\\label\{[^}]*\}' + text = re.sub(pattern, '', text) + + ref_pattern = r'\\ref\{[^}]*\}' + text = re.sub(ref_pattern, '', text) + + pageref_pattern = r'\\pageref\{[^}]*\}' + text = re.sub(pageref_pattern, '', text) + return text + + +def replace_katex_invalid(string): + # KaTeX cannot render all LaTeX, so we need to replace some things + string = re.sub(r'\\tag\{.*?\}', '', string) + string = re.sub(r'\\(?:Bigg?|bigg?)\{(.*?)\}', r'\1', string) + string = re.sub(r'\\quad\\mbox\{(.*?)\}', r'\1', string) + string = re.sub(r'\\mbox\{(.*?)\}', r'\1', string) + string = remove_inner_dollars(string) + return string + + +def remove_inner_dollars(text): + def replace_dollar(match): + # Replace single $ with nothing, keep $$ intact + math_block = match.group(1) + return '$$' + math_block.replace('$', '') + '$$' + + pattern = r'\$\$(.*?)\$\$' + return re.sub(pattern, replace_dollar, text, flags=re.DOTALL) + + +def extract_latex_with_positions(text): + pattern = r'(\$\$.*?\$\$|\$.*?\$)' + matches = [] + for match in re.finditer(pattern, text, re.DOTALL): + matches.append((match.group(), match.start(), match.end())) + return matches + + +def slice_latex(text): + # Extract LaTeX blocks along with their positions + latex_blocks_with_positions = extract_latex_with_positions(text) + + chunks = [] + last_position = 0 + for block, start, end in latex_blocks_with_positions: + # Add text before the current LaTeX block, if any + if start > last_position: + chunks.append({"text": text[last_position:start], "type": "text"}) + # Add the LaTeX block + chunks.append({"text": block, "type": "latex"}) + last_position = end + # Add remaining text after the last LaTeX block, if any + if last_position < len(text): + chunks.append({"text": text[last_position:], "type": "text"}) + + return chunks + + +def is_latex(text): + latex_patterns = [ + r'\\(?:begin|end)\{[a-zA-Z]*\}', + r'\$.*?\$', + r'\$\$.*?\$\$', + r'\\[a-zA-Z]+', + r'\\[^a-zA-Z]', + ] + + combined_pattern = '|'.join(latex_patterns) + if re.search(combined_pattern, text, re.DOTALL): + return True + + return False + + +def fix_fences(text): + if text.startswith("$$") and not text.endswith("$$"): + if text[-1] == "$": + text += "$" + else: + text += "$$" + + if text.endswith("$$") and not text.startswith("$$"): + if text[0] == "$": + text = "$" + text + else: + text = "$$" + text + + if text.startswith("$") and not text.endswith("$"): + text = "$" + text + "$$" + + if text.endswith("$") and not text.startswith("$"): + text = "$$" + text + "$" + + return text + + +def strip_fences(text): + while text.startswith("$"): + text = text[1:] + while text.endswith("$"): + text = text[:-1] + return text + + diff --git a/surya/postprocessing/math/render.py b/surya/postprocessing/math/render.py new file mode 100644 index 0000000000000000000000000000000000000000..761334a0bd923e48478075949885ed1a829ac2d9 --- /dev/null +++ b/surya/postprocessing/math/render.py @@ -0,0 +1,88 @@ +from playwright.sync_api import sync_playwright +from PIL import Image +import io + + +def latex_to_pil(latex_code, target_width, target_height, fontsize=18): + html_template = """ + + + + + + + + +
<div id="content" style="font-size: {fontsize}px">{content}</div>
+ + + + """ + + formatted_latex = latex_code.replace('\n', '\\n').replace('"', '\\"') + with sync_playwright() as p: + browser = p.chromium.launch() + page = browser.new_page() + page.set_viewport_size({'width': target_width, 'height': target_height}) + + while fontsize <= 30: + html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize)) + page.set_content(html_content) + + dimensions = page.evaluate("""() => { + const render = document.getElementById('content'); + return { + width: render.offsetWidth, + height: render.offsetHeight + }; + }""") + + if dimensions['width'] >= target_width or dimensions['height'] >= target_height: + fontsize -= 1 + break + else: + fontsize += 1 + + html_content = html_template.replace("{content}", formatted_latex).replace("{fontsize}", str(fontsize)) + page.set_content(html_content) + + screenshot_bytes = page.screenshot() + browser.close() + + image_stream = io.BytesIO(screenshot_bytes) + pil_image = Image.open(image_stream) + pil_image.load() + return pil_image \ No newline at end of file diff --git a/surya/postprocessing/text.py b/surya/postprocessing/text.py new file mode 100644 index 0000000000000000000000000000000000000000..542a80cc05de109d5f8ac86b4d976527a9142e24 --- /dev/null +++ b/surya/postprocessing/text.py @@ -0,0 +1,118 @@ +import os +from typing import List, Tuple + +import requests +from PIL import Image, ImageDraw, ImageFont + +from surya.postprocessing.fonts import get_font_path +from surya.schema import TextLine +from surya.settings import settings +from surya.postprocessing.math.latex import is_latex + + +def sort_text_lines(lines: List[TextLine] | List[dict], tolerance=1.25): + # Sorts in reading order. Not 100% accurate, this should only + # be used as a starting point for more advanced sorting. 
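+ # Bucket lines by their rounded top y so lines on roughly the same row share a key,
+ # then emit the buckets top-to-bottom, each bucket sorted left-to-right.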
+ vertical_groups = {} + for line in lines: + group_key = round(line.bbox[1] if isinstance(line, TextLine) else line["bbox"][1] / tolerance) * tolerance + if group_key not in vertical_groups: + vertical_groups[group_key] = [] + vertical_groups[group_key].append(line) + + # Sort each group horizontally and flatten the groups into a single list + sorted_lines = [] + for _, group in sorted(vertical_groups.items()): + sorted_group = sorted(group, key=lambda x: x.bbox[0] if isinstance(x, TextLine) else x["bbox"][0]) + sorted_lines.extend(sorted_group) + + return sorted_lines + + +def truncate_repetitions(text: str, min_len=15): + # From nougat, with some cleanup + if len(text) < 2 * min_len: + return text + + # try to find a length at which the tail is repeating + max_rep_len = None + for rep_len in range(min_len, int(len(text) / 2)): + # check if there is a repetition at the end + same = True + for i in range(0, rep_len): + if text[len(text) - rep_len - i - 1] != text[len(text) - i - 1]: + same = False + break + + if same: + max_rep_len = rep_len + + if max_rep_len is None: + return text + + lcs = text[-max_rep_len:] + + # remove all but the last repetition + text_to_truncate = text + while text_to_truncate.endswith(lcs): + text_to_truncate = text_to_truncate[:-max_rep_len] + + return text[:len(text_to_truncate)] + + +def get_text_size(text, font): + im = Image.new(mode="P", size=(0, 0)) + draw = ImageDraw.Draw(im) + _, _, width, height = draw.textbbox((0, 0), text=text, font=font) + return width, height + + +def render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size): + font = ImageFont.truetype(font_path, box_font_size) + text_width, text_height = get_text_size(text, font) + while (text_width > bbox_width or text_height > bbox_height) and box_font_size > 6: + box_font_size = box_font_size - 1 + font = ImageFont.truetype(font_path, box_font_size) + text_width, text_height = get_text_size(text, font) + + # Calculate text position (centered in bbox) + text_width, text_height = get_text_size(text, font) + x = s_bbox[0] + y = s_bbox[1] + (bbox_height - text_height) / 2 + + draw.text((x, y), text, fill="black", font=font) + + +def render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path): + try: + from surya.postprocessing.math.render import latex_to_pil + box_font_size = max(10, min(int(.2 * bbox_height), 24)) + img = latex_to_pil(text, bbox_width, bbox_height, fontsize=box_font_size) + img.thumbnail((bbox_width, bbox_height)) + image.paste(img, (s_bbox[0], s_bbox[1])) + except Exception as e: + print(f"Failed to render math: {e}") + box_font_size = max(10, min(int(.75 * bbox_height), 24)) + render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size) + + +def draw_text_on_image(bboxes, texts, image_size: Tuple[int, int], langs: List[str], font_path=None, max_font_size=60, res_upscale=2, has_math=False): + if font_path is None: + font_path = get_font_path(langs) + new_image_size = (image_size[0] * res_upscale, image_size[1] * res_upscale) + image = Image.new('RGB', new_image_size, color='white') + draw = ImageDraw.Draw(image) + + for bbox, text in zip(bboxes, texts): + s_bbox = [int(coord * res_upscale) for coord in bbox] + bbox_width = s_bbox[2] - s_bbox[0] + bbox_height = s_bbox[3] - s_bbox[1] + + # Shrink the text to fit in the bbox if needed + if has_math and is_latex(text): + render_math(image, draw, text, s_bbox, bbox_width, bbox_height, font_path) + else: + box_font_size = max(6, min(int(.75 * bbox_height), max_font_size)) + 
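+ # Start at roughly 75% of the box height (clamped to 6..max_font_size);
+ # render_text shrinks the font further if the string still overflows the box.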
render_text(draw, text, s_bbox, bbox_width, bbox_height, font_path, box_font_size) + + return image diff --git a/surya/postprocessing/util.py b/surya/postprocessing/util.py new file mode 100644 index 0000000000000000000000000000000000000000..e38c9403e319226a120c939043b60aae0f9fc143 --- /dev/null +++ b/surya/postprocessing/util.py @@ -0,0 +1,48 @@ +import math +import copy + + +def get_line_angle(x1, y1, x2, y2): + slope = (y2 - y1) / (x2 - x1) + + angle_radians = math.atan(slope) + angle_degrees = math.degrees(angle_radians) + + return angle_degrees + + +def rescale_bbox(bbox, processor_size, image_size): + page_width, page_height = processor_size + + img_width, img_height = image_size + width_scaler = img_width / page_width + height_scaler = img_height / page_height + + new_bbox = copy.deepcopy(bbox) + new_bbox[0] = int(new_bbox[0] * width_scaler) + new_bbox[1] = int(new_bbox[1] * height_scaler) + new_bbox[2] = int(new_bbox[2] * width_scaler) + new_bbox[3] = int(new_bbox[3] * height_scaler) + return new_bbox + + +def rescale_bboxes(bboxes, orig_size, new_size): + return [rescale_bbox(bbox, orig_size, new_size) for bbox in bboxes] + + +def rescale_point(point, processor_size, image_size): + # Point is in x, y format + page_width, page_height = processor_size + + img_width, img_height = image_size + width_scaler = img_width / page_width + height_scaler = img_height / page_height + + new_point = copy.deepcopy(point) + new_point[0] = int(new_point[0] * width_scaler) + new_point[1] = int(new_point[1] * height_scaler) + return new_point + + +def rescale_points(points, processor_size, image_size): + return [rescale_point(point, processor_size, image_size) for point in points] \ No newline at end of file diff --git a/surya/recognition.py b/surya/recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..8fcd56c0f0e9c256dde83fad91ed46bc5897717f --- /dev/null +++ b/surya/recognition.py @@ -0,0 +1,186 @@ +from typing import List +import torch +from PIL import Image + +from surya.postprocessing.math.latex import fix_math, contains_math +from surya.postprocessing.text import truncate_repetitions +from surya.settings import settings +from tqdm import tqdm +import numpy as np +import torch.nn.functional as F + + +def get_batch_size(): + batch_size = settings.RECOGNITION_BATCH_SIZE + if batch_size is None: + batch_size = 32 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 64 # 12GB RAM max + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 512 + return batch_size + + +def pad_to_batch_size(tensor, batch_size): + current_batch_size = tensor.shape[0] + if current_batch_size >= batch_size: + return tensor + + pad_size = batch_size - current_batch_size + padding = (0, 0) * (tensor.dim() - 1) + (0, pad_size) + + return F.pad(tensor, padding, mode='constant', value=0) + + +def batch_recognition(images: List, languages: List[List[str] | None], model, processor, batch_size=None): + assert all([isinstance(image, Image.Image) for image in images]) + assert len(images) == len(languages) + + if len(images) == 0: + return [], [] + + if batch_size is None: + batch_size = get_batch_size() + + # Sort images by width, so similar length ones go together + sorted_pairs = sorted(enumerate(images), key=lambda x: x[1].width, reverse=False) + indices, images = zip(*sorted_pairs) + indices = list(indices) + images = list(images) + + output_text = [] + confidences = [] + for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"): + batch_images = images[i:i+batch_size] + 
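+ # Per batch: preprocess the images, seed the decoder with the start token plus language tokens,
+ # then decode autoregressively until every sequence hits EOS/pad or RECOGNITION_MAX_TOKENS.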
batch_images = [image.convert("RGB") for image in batch_images] # also copies the images + + batch_langs = languages[i:i+batch_size] + has_math = [lang and "_math" in lang for lang in batch_langs] + + processed_batch = processor(text=[""] * len(batch_images), images=batch_images, langs=batch_langs) + + batch_pixel_values = processed_batch["pixel_values"] + batch_langs = processed_batch["langs"] + batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs] + max_input_length = max([len(tokens) for tokens in batch_decoder_input]) + + # Pad decoder input to max length if needed, to ensure we can convert to a tensor + for token_idx in range(len(batch_decoder_input)): + lang_len = len(batch_decoder_input[token_idx]) + if lang_len < max_input_length: + batch_decoder_input[token_idx] = [processor.tokenizer.pad_id] * (max_input_length - lang_len) + batch_decoder_input[token_idx] + + current_batch_size = len(batch_pixel_values) + + batch_pixel_values = torch.tensor(np.stack(batch_pixel_values, axis=0), dtype=model.dtype, device=model.device) + batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device) + + token_count = 0 + inference_token_count = batch_decoder_input.shape[-1] + batch_predictions = [[] for _ in range(current_batch_size)] + + decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype) + model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype) + + sequence_scores = None + all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device) + encoder_hidden_states = None + + with torch.no_grad(): # inference_mode doesn't work with torch.compile + encoder_batch_size = batch_size // settings.RECOGNITION_ENCODER_BATCH_DIVISOR + 1 + for z in range(0, batch_pixel_values.shape[0], encoder_batch_size): + encoder_pixel_values = batch_pixel_values[z:min(z + encoder_batch_size, batch_pixel_values.shape[0])] + encoder_hidden_states_batch = model.encoder(pixel_values=encoder_pixel_values).last_hidden_state + if encoder_hidden_states is None: + encoder_hidden_states = encoder_hidden_states_batch + else: + encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_batch], dim=0) + + text_encoder_input_ids = torch.arange( + model.text_encoder.config.query_token_count, + device=encoder_hidden_states.device, + dtype=torch.long + ).unsqueeze(0).expand(encoder_hidden_states.size(0), -1) + + encoder_text_hidden_states = model.text_encoder( + input_ids=text_encoder_input_ids, + cache_position=None, + attention_mask=None, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=None, + use_cache=False + ).hidden_states + del encoder_hidden_states + + if settings.RECOGNITION_STATIC_CACHE: + # Pad inputs to max batch size for static cache + encoder_text_hidden_states = pad_to_batch_size(encoder_text_hidden_states, batch_size) + batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size) + + while token_count < settings.RECOGNITION_MAX_TOKENS - 1: + is_prefill = token_count == 0 + #TODO: add attention mask + return_dict = model.decoder( + input_ids=batch_decoder_input, + encoder_hidden_states=encoder_text_hidden_states, + cache_position=decoder_position_ids, + use_cache=True, + prefill=is_prefill + ) + + decoder_position_ids = decoder_position_ids[-1:] + 1 + logits = 
return_dict["logits"][:current_batch_size] # Ignore batch padding + aux_logits = return_dict.get("aux_logits", None) + + preds = torch.argmax(logits[:, -1], dim=-1) + scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1) + done = (preds == processor.tokenizer.eos_id) | (preds == processor.tokenizer.pad_id) + done = done + all_done = all_done | done + + if is_prefill: + sequence_scores = scores + else: + scores = scores.masked_fill(all_done, 0) + sequence_scores = torch.cat([sequence_scores, scores], dim=1) + + if all_done.all(): + break + + batch_decoder_input = preds.unsqueeze(1) + + for j, (pred, status) in enumerate(zip(preds, all_done)): + if not status: + batch_predictions[j].append(int(pred)) + + token_count += inference_token_count + inference_token_count = batch_decoder_input.shape[-1] + max_position_id = torch.max(decoder_position_ids).item() + decoder_position_ids = torch.ones_like(batch_decoder_input[0, :], dtype=torch.int64, device=model.device).cumsum(0) - 1 + max_position_id + + if settings.RECOGNITION_STATIC_CACHE: + batch_decoder_input = pad_to_batch_size(batch_decoder_input, batch_size) + + sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1) + detected_text = processor.tokenizer.batch_decode(batch_predictions) + detected_text = [truncate_repetitions(dt) for dt in detected_text] + + # Postprocess to fix LaTeX output (add $$ signs, etc) + detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)] + output_text.extend(detected_text) + confidences.extend(sequence_scores.tolist()) + + del encoder_text_hidden_states + + output_text = sorted(zip(indices, output_text), key=lambda x: x[0]) + confidences = sorted(zip(indices, confidences), key=lambda x: x[0]) + output_text = [text for _, text in output_text] + confidences = [conf for _, conf in confidences] + return output_text, confidences + + + + + + diff --git a/surya/schema.py b/surya/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..880a77e588d4b1a59dff8cc265cb4294818cea20 --- /dev/null +++ b/surya/schema.py @@ -0,0 +1,192 @@ +import copy +from typing import List, Tuple, Any, Optional + +from pydantic import BaseModel, field_validator, computed_field + +from surya.postprocessing.util import rescale_bbox + + +class PolygonBox(BaseModel): + polygon: List[List[float]] + confidence: Optional[float] = None + + @field_validator('polygon') + @classmethod + def check_elements(cls, v: List[List[float]]) -> List[List[float]]: + if len(v) != 4: + raise ValueError('corner must have 4 elements') + + for corner in v: + if len(corner) != 2: + raise ValueError('corner must have 2 elements') + return v + + @property + def height(self): + return self.bbox[3] - self.bbox[1] + + @property + def width(self): + return self.bbox[2] - self.bbox[0] + + @property + def area(self): + return self.width * self.height + + @computed_field + @property + def bbox(self) -> List[float]: + box = [self.polygon[0][0], self.polygon[0][1], self.polygon[1][0], self.polygon[2][1]] + if box[0] > box[2]: + box[0], box[2] = box[2], box[0] + if box[1] > box[3]: + box[1], box[3] = box[3], box[1] + return box + + def rescale(self, processor_size, image_size): + # Point is in x, y format + page_width, page_height = processor_size + + img_width, img_height = image_size + width_scaler = img_width / page_width + height_scaler = img_height / page_height + + new_corners = copy.deepcopy(self.polygon) + for corner in 
new_corners: + corner[0] = int(corner[0] * width_scaler) + corner[1] = int(corner[1] * height_scaler) + self.polygon = new_corners + + def fit_to_bounds(self, bounds): + new_corners = copy.deepcopy(self.polygon) + for corner in new_corners: + corner[0] = max(min(corner[0], bounds[2]), bounds[0]) + corner[1] = max(min(corner[1], bounds[3]), bounds[1]) + self.polygon = new_corners + + def merge(self, other): + x1 = min(self.bbox[0], other.bbox[0]) + y1 = min(self.bbox[1], other.bbox[1]) + x2 = max(self.bbox[2], other.bbox[2]) + y2 = max(self.bbox[3], other.bbox[3]) + self.polygon = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] + + def intersection_area(self, other, x_margin=0, y_margin=0): + x_overlap = max(0, min(self.bbox[2] + x_margin, other.bbox[2] + x_margin) - max(self.bbox[0] - x_margin, other.bbox[0] - x_margin)) + y_overlap = max(0, min(self.bbox[3] + y_margin, other.bbox[3] + y_margin) - max(self.bbox[1] - y_margin, other.bbox[1] - y_margin)) + return x_overlap * y_overlap + + def intersection_pct(self, other, x_margin=0, y_margin=0): + assert 0 <= x_margin <= 1 + assert 0 <= y_margin <= 1 + if self.area == 0: + return 0 + + if x_margin: + x_margin = int(min(self.width, other.width) * x_margin) + if y_margin: + y_margin = int(min(self.height, other.height) * y_margin) + + intersection = self.intersection_area(other, x_margin, y_margin) + return intersection / self.area + + +class Bbox(BaseModel): + bbox: List[float] + + @field_validator('bbox') + @classmethod + def check_4_elements(cls, v: List[float]) -> List[float]: + if len(v) != 4: + raise ValueError('bbox must have 4 elements') + return v + + def rescale_bbox(self, orig_size, new_size): + self.bbox = rescale_bbox(self.bbox, orig_size, new_size) + + def round_bbox(self, divisor): + self.bbox = [x // divisor * divisor for x in self.bbox] + + @property + def height(self): + return self.bbox[3] - self.bbox[1] + + @property + def width(self): + return self.bbox[2] - self.bbox[0] + + @property + def area(self): + return self.width * self.height + + @property + def polygon(self): + return [[self.bbox[0], self.bbox[1]], [self.bbox[2], self.bbox[1]], [self.bbox[2], self.bbox[3]], [self.bbox[0], self.bbox[3]]] + + @property + def center(self): + return [(self.bbox[0] + self.bbox[2]) / 2, (self.bbox[1] + self.bbox[3]) / 2] + + def intersection_pct(self, other): + if self.area == 0: + return 0 + + x_overlap = max(0, min(self.bbox[2], other.bbox[2]) - max(self.bbox[0], other.bbox[0])) + y_overlap = max(0, min(self.bbox[3], other.bbox[3]) - max(self.bbox[1], other.bbox[1])) + intersection = x_overlap * y_overlap + return intersection / self.area + +class LayoutBox(PolygonBox): + label: str + + +class OrderBox(Bbox): + position: int + + +class ColumnLine(Bbox): + vertical: bool + horizontal: bool + + +class TextLine(PolygonBox): + text: str + confidence: Optional[float] = None + + +class OCRResult(BaseModel): + text_lines: List[TextLine] + languages: List[str] | None = None + image_bbox: List[float] + + +class TextDetectionResult(BaseModel): + bboxes: List[PolygonBox] + vertical_lines: List[ColumnLine] + heatmap: Any + affinity_map: Any + image_bbox: List[float] + + +class LayoutResult(BaseModel): + bboxes: List[LayoutBox] + segmentation_map: Any + image_bbox: List[float] + + +class OrderResult(BaseModel): + bboxes: List[OrderBox] + image_bbox: List[float] + + +class TableCell(Bbox): + row_id: int | None = None + col_id: int | None = None + text: str | None = None + + +class TableResult(BaseModel): + cells: List[TableCell] + rows: List[TableCell] 
+ cols: List[TableCell] + image_bbox: List[float] diff --git a/surya/settings.py b/surya/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..12f99ad78ede2028f50793aae5946637e7f50453 --- /dev/null +++ b/surya/settings.py @@ -0,0 +1,96 @@ +from typing import Dict, Optional + +from dotenv import find_dotenv +from pydantic import computed_field +from pydantic_settings import BaseSettings +import torch +import os + + +class Settings(BaseSettings): + # General + TORCH_DEVICE: Optional[str] = None + IMAGE_DPI: int = 96 # Used for detection, layout, reading order + IMAGE_DPI_HIGHRES: int = 192 # Used for OCR, table rec + IN_STREAMLIT: bool = False # Whether we're running in streamlit + ENABLE_EFFICIENT_ATTENTION: bool = True # Usually keep True, but if you get CUDA errors, setting to False can help + + # Paths + DATA_DIR: str = "data" + RESULT_DIR: str = "results" + BASE_DIR: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + FONT_DIR: str = os.path.join(BASE_DIR, "static", "fonts") + + @computed_field + def TORCH_DEVICE_MODEL(self) -> str: + if self.TORCH_DEVICE is not None: + return self.TORCH_DEVICE + + if torch.cuda.is_available(): + return "cuda" + + if torch.backends.mps.is_available(): + return "mps" + + return "cpu" + + # Text detection + DETECTOR_BATCH_SIZE: Optional[int] = None # Defaults to 2 for CPU/MPS, 32 otherwise + DETECTOR_MODEL_CHECKPOINT: str = "vikp/surya_det3" + DETECTOR_BENCH_DATASET_NAME: str = "vikp/doclaynet_bench" + DETECTOR_IMAGE_CHUNK_HEIGHT: int = 1400 # Height at which to slice images vertically + DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text) + DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank) + DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing + DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize + + # Text recognition + RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec2" + RECOGNITION_MAX_TOKENS: int = 175 + RECOGNITION_BATCH_SIZE: Optional[int] = None # Defaults to 8 for CPU/MPS, 256 otherwise + RECOGNITION_IMAGE_SIZE: Dict = {"height": 256, "width": 896} + RECOGNITION_RENDER_FONTS: Dict[str, str] = { + "all": os.path.join(FONT_DIR, "GoNotoCurrent-Regular.ttf"), + "zh": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), + "ja": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), + "ko": os.path.join(FONT_DIR, "GoNotoCJKCore.ttf"), + } + RECOGNITION_FONT_DL_BASE: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0" + RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench" + RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255 + RECOGNITION_STATIC_CACHE: bool = False # Static cache for torch compile + RECOGNITION_ENCODER_BATCH_DIVISOR: int = 2 # Divisor for batch size in decoder + + # Layout + LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout3" + LAYOUT_BENCH_DATASET_NAME: str = "vikp/publaynet_bench" + + # Ordering + ORDER_MODEL_CHECKPOINT: str = "vikp/surya_order" + ORDER_IMAGE_SIZE: Dict = {"height": 1024, "width": 1024} + ORDER_MAX_BOXES: int = 256 + ORDER_BATCH_SIZE: Optional[int] = None # Defaults to 4 for CPU/MPS, 32 otherwise + ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench" + + # Table Rec + TABLE_REC_MODEL_CHECKPOINT: str = "vikp/surya_tablerec" + TABLE_REC_IMAGE_SIZE: Dict = {"height": 640, "width": 640} + TABLE_REC_MAX_BOXES: int = 512 + TABLE_REC_MAX_ROWS: int = 384 + TABLE_REC_BATCH_SIZE: Optional[int] = 
None + TABLE_REC_BENCH_DATASET_NAME: str = "vikp/fintabnet_bench" + + # Tesseract (for benchmarks only) + TESSDATA_PREFIX: Optional[str] = None + + @computed_field + @property + def MODEL_DTYPE(self) -> torch.dtype: + return torch.float32 if self.TORCH_DEVICE_MODEL == "cpu" else torch.float16 + + class Config: + env_file = find_dotenv("local.env") + extra = "ignore" + + +settings = Settings() \ No newline at end of file diff --git a/surya/tables.py b/surya/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..9e533447d51afe97d76fdb4ab962f2c022f08805 --- /dev/null +++ b/surya/tables.py @@ -0,0 +1,259 @@ +from collections import defaultdict +from copy import deepcopy +from typing import List, Dict +import torch +from PIL import Image + +from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel +from surya.schema import TableResult, TableCell, Bbox +from surya.settings import settings +from tqdm import tqdm +import numpy as np +from surya.model.table_rec.config import SPECIAL_TOKENS + + +def get_batch_size(): + batch_size = settings.TABLE_REC_BATCH_SIZE + if batch_size is None: + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "mps": + batch_size = 8 + if settings.TORCH_DEVICE_MODEL == "cuda": + batch_size = 64 + return batch_size + + +def sort_bboxes(bboxes, tolerance=1): + vertical_groups = {} + for block in bboxes: + group_key = round(block["bbox"][1] / tolerance) * tolerance + if group_key not in vertical_groups: + vertical_groups[group_key] = [] + vertical_groups[group_key].append(block) + + # Sort each group horizontally and flatten the groups into a single list + sorted_page_blocks = [] + for _, group in sorted(vertical_groups.items()): + sorted_group = sorted(group, key=lambda x: x["bbox"][0]) + sorted_page_blocks.extend(sorted_group) + + return sorted_page_blocks + + +def is_rotated(rows, cols): + # Determine if the table is rotated by looking at row and column width / height ratios + # Rows should have a >1 ratio, cols <1 + widths = sum([r.width for r in rows]) + heights = sum([c.height for c in rows]) + 1 + r_ratio = widths / heights + + widths = sum([c.width for c in cols]) + heights = sum([r.height for r in cols]) + 1 + c_ratio = widths / heights + + return r_ratio * 2 < c_ratio + +def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]: + assert all([isinstance(image, Image.Image) for image in images]) + assert len(images) == len(table_cells) + if batch_size is None: + batch_size = get_batch_size() + + output_order = [] + for i in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"): + batch_table_cells = deepcopy(table_cells[i:i+batch_size]) + batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells] # Sort bboxes before passing in + batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells] + + batch_images = images[i:i+batch_size] + batch_images = [image.convert("RGB") for image in batch_images] # also copies the images + current_batch_size = len(batch_images) + + orig_sizes = [image.size for image in batch_images] + model_inputs = processor(images=batch_images, boxes=deepcopy(batch_list_bboxes)) + + batch_pixel_values = model_inputs["pixel_values"] + batch_bboxes = model_inputs["input_boxes"] + batch_bbox_mask = model_inputs["input_boxes_mask"] + batch_bbox_counts = model_inputs["input_boxes_counts"] + + batch_bboxes = torch.from_numpy(np.array(batch_bboxes, 
dtype=np.int32)).to(model.device) + batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device) + batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device) + batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device) + + # Setup inputs for the decoder + batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)] + batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device) + inference_token_count = batch_decoder_input.shape[1] + + max_tokens = min(batch_bbox_counts[:, 1].max().item(), settings.TABLE_REC_MAX_BOXES) + decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1 + model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype) + model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype) + + batch_predictions = [[] for _ in range(current_batch_size)] + + with torch.inference_mode(): + encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state + text_encoder_hidden_states = model.text_encoder( + input_boxes=batch_bboxes, + input_boxes_counts=batch_bbox_counts, + cache_position=None, + attention_mask=batch_bbox_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=None, + use_cache=False + ).hidden_states + + token_count = 0 + all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device) + + while token_count < max_tokens: + is_prefill = token_count == 0 + return_dict = model.decoder( + input_ids=batch_decoder_input, + encoder_hidden_states=text_encoder_hidden_states, + cache_position=decoder_position_ids, + use_cache=True, + prefill=is_prefill + ) + + decoder_position_ids = decoder_position_ids[-1:] + 1 + box_logits = return_dict["bbox_logits"][:, -1, :].detach() + rowcol_logits = return_dict["class_logits"][:, -1, :].detach() + + rowcol_preds = torch.argmax(rowcol_logits, dim=-1) + box_preds = torch.argmax(box_logits, dim=-1) + + done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id) + done = done + all_done = all_done | done + + if all_done.all(): + break + + batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1) + + for j, (pred, status) in enumerate(zip(batch_decoder_input, all_done)): + if not status: + batch_predictions[j].append(pred[0].tolist()) + + token_count += inference_token_count + inference_token_count = batch_decoder_input.shape[1] + + for j, (preds, input_cells, orig_size) in enumerate(zip(batch_predictions, batch_table_cells, orig_sizes)): + img_w, img_h = orig_size + width_scaler = img_w / model.config.decoder.out_box_size + height_scaler = img_h / model.config.decoder.out_box_size + + # cx, cy to corners + for i, pred in enumerate(preds): + w = pred[2] / 2 + h = pred[3] / 2 + x1 = pred[0] - w + y1 = pred[1] - h + x2 = pred[0] + w + y2 = pred[1] + h + class_ = int(pred[4] - SPECIAL_TOKENS) + + preds[i] = [x1 * width_scaler, y1 * height_scaler, x2 * width_scaler, y2 * height_scaler, class_] + + # Get rows and columns + bb_rows = [p[:4] for p in preds if p[4] == 0] + bb_cols = [p[:4] for p in preds if p[4] == 1] + + rows = [] + cols = [] + for row_idx, row in enumerate(bb_rows): + cell = TableCell( + bbox=row, + row_id=row_idx + ) + rows.append(cell) + + for col_idx, col in 
enumerate(bb_cols): + cell = TableCell( + bbox=col, + col_id=col_idx, + ) + cols.append(cell) + + # Assign cells to rows/columns + cells = [] + for cell in input_cells: + max_intersection = 0 + row_pred = None + for row_idx, row in enumerate(rows): + intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(row) + if intersection_pct > max_intersection: + max_intersection = intersection_pct + row_pred = row_idx + + max_intersection = 0 + col_pred = None + for col_idx, col in enumerate(cols): + intersection_pct = Bbox(bbox=cell["bbox"]).intersection_pct(col) + if intersection_pct > max_intersection: + max_intersection = intersection_pct + col_pred = col_idx + + cells.append( + TableCell( + bbox=cell["bbox"], + text=cell.get("text"), + row_id=row_pred, + col_id=col_pred + ) + ) + + rotated = is_rotated(rows, cols) + for cell in cells: + if cell.row_id is None: + closest_row = None + closest_row_dist = None + for cell2 in cells: + if cell2.row_id is None: + continue + if rotated: + cell_y_center = cell.center[0] + cell2_y_center = cell2.center[0] + else: + cell_y_center = cell.center[1] + cell2_y_center = cell2.center[1] + y_dist = abs(cell_y_center - cell2_y_center) + if closest_row_dist is None or y_dist < closest_row_dist: + closest_row = cell2.row_id + closest_row_dist = y_dist + cell.row_id = closest_row + + if cell.col_id is None: + closest_col = None + closest_col_dist = None + for cell2 in cells: + if cell2.col_id is None: + continue + if rotated: + cell_x_center = cell.center[1] + cell2_x_center = cell2.center[1] + else: + cell_x_center = cell.center[0] + cell2_x_center = cell2.center[0] + + x_dist = abs(cell2_x_center - cell_x_center) + if closest_col_dist is None or x_dist < closest_col_dist: + closest_col = cell2.col_id + closest_col_dist = x_dist + + cell.col_id = closest_col + + result = TableResult( + cells=cells, + rows=rows, + cols=cols, + image_bbox=[0, 0, img_w, img_h], + ) + + output_order.append(result) + + return output_order \ No newline at end of file diff --git a/table_recognition.py b/table_recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..4e03335c8a84274461c784a8ff1e48051b41f2bc --- /dev/null +++ b/table_recognition.py @@ -0,0 +1,146 @@ +import pypdfium2 as pdfium # Needs to be on top to avoid warning +import os +import argparse +import copy +import json +from collections import defaultdict + +from surya.detection import batch_text_detection +from surya.input.load import load_from_folder, load_from_file +from surya.input.pdflines import get_table_blocks +from surya.layout import batch_layout_detection +from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor +from surya.model.table_rec.model import load_model as load_model +from surya.model.table_rec.processor import load_processor +from surya.tables import batch_table_recognition +from surya.postprocessing.heatmap import draw_bboxes_on_image +from surya.settings import settings +from surya.postprocessing.util import rescale_bboxes, rescale_bbox + + +def main(): + parser = argparse.ArgumentParser(description="Find reading order of an input file or folder (PDFs or image).") + parser.add_argument("input_path", type=str, help="Path to pdf or image file or folder to find reading order in.") + parser.add_argument("--results_dir", type=str, help="Path to JSON file with layout results.", default=os.path.join(settings.RESULT_DIR, "surya")) + parser.add_argument("--max", type=int, help="Maximum number of pages to process.", default=None) + 
parser.add_argument("--images", action="store_true", help="Save images of detected layout bboxes.", default=False) + parser.add_argument("--detect_boxes", action="store_true", help="Detect table boxes.", default=False) + parser.add_argument("--skip_table_detection", action="store_true", help="Tables are already cropped, so don't re-detect tables.", default=False) + args = parser.parse_args() + + model = load_model() + processor = load_processor() + + layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT) + + det_model = load_det_model() + det_processor = load_det_processor() + + if os.path.isdir(args.input_path): + images, _, _ = load_from_folder(args.input_path, args.max) + highres_images, names, text_lines = load_from_folder(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES, load_text_lines=True) + folder_name = os.path.basename(args.input_path) + else: + images, _, _ = load_from_file(args.input_path, args.max) + highres_images, names, text_lines = load_from_file(args.input_path, args.max, dpi=settings.IMAGE_DPI_HIGHRES, load_text_lines=True) + folder_name = os.path.basename(args.input_path).split(".")[0] + + pnums = [] + prev_name = None + for i, name in enumerate(names): + if prev_name is None or prev_name != name: + pnums.append(0) + else: + pnums.append(pnums[-1] + 1) + + prev_name = name + + line_predictions = batch_text_detection(images, det_model, det_processor) + layout_predictions = batch_layout_detection(images, layout_model, layout_processor, line_predictions) + table_cells = [] + + table_imgs = [] + table_counts = [] + + for layout_pred, text_line, img, highres_img in zip(layout_predictions, text_lines, images, highres_images): + # The table may already be cropped + if args.skip_table_detection: + table_imgs.append(highres_img) + table_counts.append(1) + page_table_imgs = [highres_img] + highres_bbox = [[0, 0, highres_img.size[0], highres_img.size[1]]] + else: + # The bbox for the entire table + bbox = [l.bbox for l in layout_pred.bboxes if l.label == "Table"] + # Number of tables per page + table_counts.append(len(bbox)) + + if len(bbox) == 0: + continue + + page_table_imgs = [] + highres_bbox = [] + for bb in bbox: + highres_bb = rescale_bbox(bb, img.size, highres_img.size) + page_table_imgs.append(highres_img.crop(highres_bb)) + highres_bbox.append(highres_bb) + + table_imgs.extend(page_table_imgs) + + # The text cells inside each table + table_blocks = get_table_blocks(highres_bbox, text_line, highres_img.size) if text_line is not None else None + if text_line is None or args.detect_boxes or any(len(tb) == 0 for tb in table_blocks): + det_results = batch_text_detection(page_table_imgs, det_model, det_processor,) + cell_bboxes = [[{"bbox": tb.bbox, "text": None} for tb in det_result.bboxes] for det_result in det_results] + table_cells.extend(cell_bboxes) + else: + table_cells.extend(table_blocks) + + table_preds = batch_table_recognition(table_imgs, table_cells, model, processor) + result_path = os.path.join(args.results_dir, folder_name) + os.makedirs(result_path, exist_ok=True) + + img_idx = 0 + prev_count = 0 + table_predictions = defaultdict(list) + for i in range(sum(table_counts)): + while i >= prev_count + table_counts[img_idx]: + prev_count += table_counts[img_idx] + img_idx += 1 + + pred = table_preds[i] + orig_name = names[img_idx] + pnum = pnums[img_idx] + table_img = table_imgs[i] + + out_pred = pred.model_dump() + out_pred["page"] = pnum + 1 + 
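+ # i counts tables across the whole run; subtracting prev_count gives this table's index within its page.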
table_idx = i - prev_count + out_pred["table_idx"] = table_idx + table_predictions[orig_name].append(out_pred) + + if args.images: + boxes = [l.bbox for l in pred.cells] + labels = [f"{l.row_id}/{l.col_id}" for l in pred.cells] + bbox_image = draw_bboxes_on_image(boxes, copy.deepcopy(table_img), labels=labels, label_font_size=20) + bbox_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_cells.png")) + + rows = [l.bbox for l in pred.rows] + cols = [l.bbox for l in pred.cols] + row_labels = [f"Row {l.row_id}" for l in pred.rows] + col_labels = [f"Col {l.col_id}" for l in pred.cols] + + rc_image = copy.deepcopy(table_img) + rc_image = draw_bboxes_on_image(rows, rc_image, labels=row_labels, label_font_size=20, color="blue") + rc_image = draw_bboxes_on_image(cols, rc_image, labels=col_labels, label_font_size=20, color="red") + rc_image.save(os.path.join(result_path, f"{name}_page{pnum + 1}_table{table_idx}_rc.png")) + + with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f: + json.dump(table_predictions, f, ensure_ascii=False) + + print(f"Wrote results to {result_path}") + + +if __name__ == "__main__": + main() \ No newline at end of file
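Not part of the patch above: a minimal usage sketch showing how the new pieces fit together when a table image is already cropped (the --skip_table_detection path in table_recognition.py). It reuses the loaders imported at the top of that script, aliased locally for clarity; "table.png" is a placeholder path, and the {"bbox", "text"} cell dicts mirror what the script builds from text detection when no PDF text lines are available.

from PIL import Image

from surya.detection import batch_text_detection
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.table_rec.model import load_model as load_table_model
from surya.model.table_rec.processor import load_processor as load_table_processor
from surya.tables import batch_table_recognition

# Placeholder: an image that already contains just the table crop
table_img = Image.open("table.png").convert("RGB")

det_model, det_processor = load_det_model(), load_det_processor()
table_model, table_processor = load_table_model(), load_table_processor()

# Detect text lines inside the crop and convert them into the cell dicts
# that batch_table_recognition expects ({"bbox": [...], "text": ...})
det_result = batch_text_detection([table_img], det_model, det_processor)[0]
cells = [{"bbox": box.bbox, "text": None} for box in det_result.bboxes]

# One TableResult per input image: predicted rows, columns, and row/col ids per cell
table_result = batch_table_recognition([table_img], [cells], table_model, table_processor)[0]
for cell in table_result.cells:
    print(cell.row_id, cell.col_id, cell.bbox, cell.text)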