# Streamlit demo: YOLO sketch-symbol detection with TrOCR text transcription.
from io import BytesIO

import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
from ultralytics import YOLO  # Make sure this import works in your Hugging Face environment
def load_ocr_model(): | |
""" | |
Load and cache the ocr model and processor | |
""" | |
model = VisionEncoderDecoderModel.from_pretrained('edesaras/TROCR_finetuned_on_CSTA', cache_dir='./models/TrOCR') | |
processor = TrOCRProcessor.from_pretrained("edesaras/TROCR_finetuned_on_CSTA", cache_dir='./models/TrOCR') | |
return model, processor | |
def load_model(): | |
""" | |
Load and cache the model | |
""" | |
model = YOLO('./models/YOLO/weights.pt') | |
return model | |
def predict(model, image, font_size, line_width): | |
""" | |
Run inference and return annotated image | |
""" | |
results = model.predict(image) | |
r = results[0] | |
im_bgr = r.plot(conf=False, pil=True, font_size=font_size, line_width=line_width) # Returns a PIL image if pil=True | |
im_rgb = Image.fromarray(im_bgr[..., ::-1]) # Convert BGR to RGB | |
return im_rgb, r | |
def extract_text_patches(result, image): | |
image = np.array(image) | |
text_bboxes = [] | |
for i, label in enumerate([result.names[id.item()] for id in result.boxes.cls]): | |
if label == 'text': | |
bbox = result.boxes.xyxy[i] | |
text_bboxes.append([round(i.item()) for i in bbox]) | |
crops = [] | |
for box in text_bboxes: | |
xmin, ymin, xmax, ymax = box | |
crop_img = image[ymin:ymax, xmin:xmax] | |
crops.append(crop_img) | |
return crops, text_bboxes | |
def ocr_predict(model, processor, crops): | |
pixel_values = processor(crops, return_tensors="pt").pixel_values | |
# Generate text with TrOCR | |
generated_ids = model.generate(pixel_values) | |
texts = processor.batch_decode(generated_ids, skip_special_tokens=True) | |
return texts | |
def file_uploader_cb(model, ocr_model, ocr_processor, uploaded_file, font_size, line_width): | |
image = Image.open(uploaded_file).convert("RGB") | |
col1, col2 = st.columns(2) | |
with col1: | |
# Display Uploaded image | |
st.image(image, caption='Uploaded Image', use_column_width=True) | |
# Perform inference | |
annotated_img, result = predict(model, image, font_size, line_width) | |
with col2: | |
# Display the prediction | |
st.image(annotated_img, caption='Prediction', use_column_width=True) | |
# write image to memory buffer for download | |
imbuffer = BytesIO() | |
annotated_img.save(imbuffer, format="JPEG") | |
st.download_button("Download Annotated Image", data=imbuffer, file_name="Annotated_Sketch.jpeg", mime="image/jpeg", key="upload") | |
st.subheader('Transcription') | |
crops, text_bboxes = extract_text_patches(result, image) | |
texts = ocr_predict(ocr_model, ocr_processor, crops) | |
transcription_df = pd.DataFrame(zip(texts, *np.array(text_bboxes).T, [st.image(crop) for crop in crops]), | |
columns=['Transcription', 'xmin', 'ymin', 'xmax', 'ymax', 'Image']) | |
st.dataframe(transcription_df) | |
def image_capture_cb(model, ocr_model, ocr_processor, capture, font_size, line_width, col): | |
image = Image.open(capture).convert("RGB") | |
# Perform inference | |
annotated_img, result = predict(model, image, font_size, line_width) | |
with col: | |
# Display the prediction | |
st.image(annotated_img, caption='Prediction', use_column_width=True) | |
# write image to memory buffer for download | |
imbuffer = BytesIO() | |
annotated_img.save(imbuffer, format="JPEG") | |
st.download_button("Download Annotated Image", data=imbuffer, file_name="Annotated_Sketch.jpeg", mime="image/jpeg", key="capture") | |
st.subheader('Transcription') | |
crops, text_bboxes = extract_text_patches(result, image) | |
texts = ocr_predict(ocr_model, ocr_processor, crops) | |
transcription_df = pd.DataFrame(zip(texts, *np.array(text_bboxes).T), | |
columns=['Transcription', 'xmin', 'ymin', 'xmax', 'ymax']) | |
st.dataframe(transcription_df) | |