edesaras committed
Commit 4929692 (parent: 84e12ef)

Added an OCR model, replaced the old YOLO model with a new one trained with rotation augmentation, and converted the Streamlit tabs into a multipage app

.gitignore ADDED
@@ -0,0 +1,3 @@
+__pycache__
+main.ipynb
+blankexample.jpeg
Hello.py ADDED
@@ -0,0 +1,15 @@
+import streamlit as st
+
+if __name__ == "__main__":
+    # set page configurations and display/annotation options
+    st.set_page_config(
+        page_title="Circuit Sketch Recognizer",
+        layout="wide"
+    )
+
+    st.title("Circuit Sketch Recognition")
+    col1, col2 = st.columns(2)
+    with col1:
+        st.image('example1.jpg', use_column_width=True, caption='Example 1')
+    with col2:
+        st.image('example2.jpg', use_column_width=True, caption='Example 2')
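Side note (not part of the commit): with this change the app follows Streamlit's multipage convention, where the script passed to `streamlit run` (here `Hello.py`) is the landing page and every script under a sibling `pages/` directory is registered as its own page, replacing the tab logic removed from `app.py` below. A quick sanity check of the layout this commit assumes, run from the repository root (the helper itself is hypothetical, not in the repo):

```python
# Hypothetical check, not in the repo: verify the files this commit expects
# before launching `streamlit run Hello.py` from the repository root.
from pathlib import Path

expected = [
    "Hello.py",                  # landing page
    "utils.py",                  # shared model-loading and inference helpers
    "pages/Capture_Image.py",    # camera page
    "pages/Upload_An_Image.py",  # upload page
    "models/YOLO/weights.pt",    # relocated detector weights
]
missing = [p for p in expected if not Path(p).exists()]
print("missing files:", missing or "none")
```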
weights.pt → models/YOLO/weights.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0791285b924f954f0370a13739ca87e2569a90bf935c0afdd69797f9dc2bbf0a
-size 52120385
+oid sha256:8f93346972611fd027af6c1b1dfc9cd818f48e794d682466e2ef3ba6042721df
+size 52163457
pages/Capture_Image.py ADDED
@@ -0,0 +1,26 @@
+import streamlit as st
+import sys
+import os
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+from utils import load_model, image_capture_cb, load_ocr_model
+
+if __name__ == "__main__":
+    # set page configurations and display/annotation options
+    st.set_page_config(
+        page_title="Circuit Sketch Recognizer",
+        layout="wide"
+    )
+
+    with st.sidebar:
+        font_size = st.slider(label="Font Size", min_value=6, max_value=64, step=1, value=24)
+        line_width = st.slider(label="Bounding Box Line Thickness", min_value=1, max_value=8, step=1, value=3)
+
+    model = load_model()
+    ocr_model, ocr_processor = load_ocr_model()
+
+    # Camera Input allows user to take a picture
+    col1, col2 = st.columns(2)
+    with col1:
+        capture = st.camera_input("Take a picture with Camera")
+        if capture is not None:
+            image_capture_cb(model, ocr_model, ocr_processor, capture, font_size, line_width, col2)
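Aside (not part of the commit): both page scripts prepend the parent directory to `sys.path` so that `from utils import ...` resolves when Streamlit runs them out of `pages/`. A common alternative is to anchor the path on the page file itself instead of `sys.path[0]`; a sketch using only the standard library:

```python
# Hypothetical alternative, not in the commit: make the project root importable
# relative to this page file.
import sys
from pathlib import Path

PROJECT_ROOT = Path(__file__).resolve().parent.parent  # .../<repo> from .../<repo>/pages
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(1, str(PROJECT_ROOT))

from utils import load_model, load_ocr_model, image_capture_cb  # same imports as the page
```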
pages/Upload_An_Image.py ADDED
@@ -0,0 +1,24 @@
+import streamlit as st
+import sys
+import os
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+from utils import load_model, file_uploader_cb, load_ocr_model
+
+if __name__ == "__main__":
+    # set page configurations and display/annotation options
+    st.set_page_config(
+        page_title="Circuit Sketch Recognizer",
+        layout="wide"
+    )
+
+    with st.sidebar:
+        font_size = st.slider(label="Font Size", min_value=6, max_value=64, step=1, value=24)
+        line_width = st.slider(label="Bounding Box Line Thickness", min_value=1, max_value=8, step=1, value=3)
+
+    model = load_model()
+    ocr_model, ocr_processor = load_ocr_model()
+
+    # File uploader allows user to add their own image
+    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+    if uploaded_file is not None:
+        file_uploader_cb(model, ocr_model, ocr_processor, uploaded_file, font_size, line_width)
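Aside (not part of the commit): `st.file_uploader` and `st.camera_input` both return a file-like `UploadedFile`, which is why the two pages can pass their widget output straight to the `Image.open(...)` calls in `utils.py`. A minimal sketch of that shared handling; the helper name is my own:

```python
# Hypothetical helper, not in the repo: normalize either widget's output to a
# PIL RGB image, mirroring what file_uploader_cb and image_capture_cb do.
from PIL import Image

def to_rgb_image(streamlit_file):
    """Accepts the UploadedFile returned by st.file_uploader or st.camera_input."""
    return Image.open(streamlit_file).convert("RGB")
```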
app.py → utils.py RENAMED
@@ -1,16 +1,26 @@
 import streamlit as st
 from PIL import Image
-import numpy as np
 from ultralytics import YOLO # Make sure this import works in your Hugging Face environment
 from io import BytesIO
+import numpy as np
+import pandas as pd
+from transformers import VisionEncoderDecoderModel, TrOCRProcessor
 
+@st.cache_resource
+def load_ocr_model():
+    """
+    Load and cache the ocr model and processor
+    """
+    model = VisionEncoderDecoderModel.from_pretrained('edesaras/TROCR_finetuned_on_CSTA', cache_dir='./models/TrOCR')
+    processor = TrOCRProcessor.from_pretrained("edesaras/TROCR_finetuned_on_CSTA", cache_dir='./models/TrOCR')
+    return model, processor
 
 @st.cache_resource
 def load_model():
     """
     Load and cache the model
     """
-    model = YOLO("weights.pt") # Adjust path if needed
+    model = YOLO('./models/YOLO/weights.pt')
     return model
 
 def predict(model, image, font_size, line_width):
@@ -21,16 +31,37 @@ def predict(model, image, font_size, line_width):
     r = results[0]
     im_bgr = r.plot(conf=False, pil=True, font_size=font_size, line_width=line_width) # Returns a PIL image if pil=True
     im_rgb = Image.fromarray(im_bgr[..., ::-1]) # Convert BGR to RGB
-    return im_rgb
+    return im_rgb, r
 
-def file_uploader_cb(uploaded_file, font_size, line_width):
+def extract_text_patches(result, image):
+    image = np.array(image)
+    text_bboxes = []
+    for i, label in enumerate([result.names[id.item()] for id in result.boxes.cls]):
+        if label == 'text':
+            bbox = result.boxes.xyxy[i]
+            text_bboxes.append([round(i.item()) for i in bbox])
+    crops = []
+    for box in text_bboxes:
+        xmin, ymin, xmax, ymax = box
+        crop_img = image[ymin:ymax, xmin:xmax]
+        crops.append(crop_img)
+    return crops, text_bboxes
+
+def ocr_predict(model, processor, crops):
+    pixel_values = processor(crops, return_tensors="pt").pixel_values
+    # Generate text with TrOCR
+    generated_ids = model.generate(pixel_values)
+    texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
+    return texts
+
+def file_uploader_cb(model, ocr_model, ocr_processor, uploaded_file, font_size, line_width):
     image = Image.open(uploaded_file).convert("RGB")
     col1, col2 = st.columns(2)
     with col1:
         # Display Uploaded image
         st.image(image, caption='Uploaded Image', use_column_width=True)
         # Perform inference
-        annotated_img = predict(model, image, font_size, line_width)
+        annotated_img, result = predict(model, image, font_size, line_width)
     with col2:
         # Display the prediction
         st.image(annotated_img, caption='Prediction', use_column_width=True)
@@ -38,11 +69,18 @@ def file_uploader_cb(uploaded_file, font_size, line_width):
         imbuffer = BytesIO()
         annotated_img.save(imbuffer, format="JPEG")
         st.download_button("Download Annotated Image", data=imbuffer, file_name="Annotated_Sketch.jpeg", mime="image/jpeg", key="upload")
+
+    st.subheader('Transcription')
+    crops, text_bboxes = extract_text_patches(result, image)
+    texts = ocr_predict(ocr_model, ocr_processor, crops)
+    transcription_df = pd.DataFrame(zip(texts, *np.array(text_bboxes).T, [st.image(crop) for crop in crops]),
+                                    columns=['Transcription', 'xmin', 'ymin', 'xmax', 'ymax', 'Image'])
+    st.dataframe(transcription_df)
 
-def image_capture_cb(capture, font_size, line_width, col):
+def image_capture_cb(model, ocr_model, ocr_processor, capture, font_size, line_width, col):
     image = Image.open(capture).convert("RGB")
     # Perform inference
-    annotated_img = predict(model, image, font_size, line_width)
+    annotated_img, result = predict(model, image, font_size, line_width)
     with col:
         # Display the prediction
         st.image(annotated_img, caption='Prediction', use_column_width=True)
@@ -51,36 +89,9 @@ def image_capture_cb(capture, font_size, line_width, col):
         annotated_img.save(imbuffer, format="JPEG")
        st.download_button("Download Annotated Image", data=imbuffer, file_name="Annotated_Sketch.jpeg", mime="image/jpeg", key="capture")
 
-if __name__ == "__main__":
-    # set page configurations and display/annotation options
-    st.set_page_config(
-        page_title="Circuit Sketch Recognizer",
-        layout="wide"
-    )
-    st.title("Circuit Sketch Recognition")
-    with st.sidebar:
-        font_size = st.slider(label="Font Size", min_value=6, max_value=64, step=1, value=24)
-        line_width = st.slider(label="Bounding Box Line Thickness", min_value=1, max_value=8, step=1, value=3)
-
-    model = load_model()
-
-    # user specifies to take/upload picture, view examples
-    tabs = st.tabs(["Capture Picture", "Upload Your Image", "Show Examples"])
-    with tabs[0]:
-        # File uploader allows user to add their own image
-        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-        if uploaded_file is not None:
-            file_uploader_cb(uploaded_file, font_size, line_width)
-    with tabs[1]:
-        # Camera Input allows user to take a picture
-        col1, col2 = st.columns(2)
-        with col1:
-            capture = st.camera_input("Take a picture with Camera")
-            if capture is not None:
-                image_capture_cb(capture, font_size, line_width, col2)
-    with tabs[2]:
-        col1, col2 = st.columns(2)
-        with col1:
-            st.image('example1.jpg', use_column_width=True, caption='Example 1')
-        with col2:
-            st.image('example2.jpg', use_column_width=True, caption='Example 2')
+    st.subheader('Transcription')
+    crops, text_bboxes = extract_text_patches(result, image)
+    texts = ocr_predict(ocr_model, ocr_processor, crops)
+    transcription_df = pd.DataFrame(zip(texts, *np.array(text_bboxes).T),
+                                    columns=['Transcription', 'xmin', 'ymin', 'xmax', 'ymax'])
+    st.dataframe(transcription_df)
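For readers who want to exercise the new detect-then-transcribe flow outside Streamlit, here is a minimal offline sketch that mirrors `predict`, `extract_text_patches`, and `ocr_predict` from this commit. It assumes the weight path and Hugging Face repo id used above and one of the example images from the repo; it is an illustration, not part of the commit.

```python
# Offline sketch (not in the commit): run the YOLO detector, crop the boxes
# labeled 'text', and transcribe them with the fine-tuned TrOCR model.
import numpy as np
from PIL import Image
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

detector = YOLO("./models/YOLO/weights.pt")
processor = TrOCRProcessor.from_pretrained("edesaras/TROCR_finetuned_on_CSTA")
ocr = VisionEncoderDecoderModel.from_pretrained("edesaras/TROCR_finetuned_on_CSTA")

image = Image.open("example1.jpg").convert("RGB")
result = detector.predict(image)[0]

# Collect a crop for every detection whose class name is 'text'
arr = np.array(image)
crops = []
for cls_id, box in zip(result.boxes.cls, result.boxes.xyxy):
    if result.names[int(cls_id)] == "text":
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        crops.append(arr[ymin:ymax, xmin:xmax])

if crops:
    # Batch the crops through TrOCR and decode the generated token ids
    pixel_values = processor(images=crops, return_tensors="pt").pixel_values
    generated_ids = ocr.generate(pixel_values)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```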