edesaras commited on
Commit
49c2c74
1 Parent(s): 8ebf841

New Features: Tab View, Annotated Image Download, Take pictures as well as upload images

Browse files
Files changed (2) hide show
  1. .streamlit/config.toml +3 -0
  2. app.py +70 -21
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [server]
2
+
3
+ maxUploadSize = 20
app.py CHANGED
@@ -2,36 +2,85 @@ import streamlit as st
2
  from PIL import Image
3
  import numpy as np
4
  from ultralytics import YOLO # Make sure this import works in your Hugging Face environment
 
 
5
 
6
- # Load the model
7
  @st.cache_resource
8
  def load_model():
 
 
 
9
  model = YOLO("weights.pt") # Adjust path if needed
10
  return model
11
 
12
- model = load_model()
13
-
14
- st.title("Circuit Sketch Recognition")
15
-
16
- # File uploader allows user to add their own image
17
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
 
18
 
19
- if uploaded_file is not None:
20
  image = Image.open(uploaded_file).convert("RGB")
21
- st.image(image, caption='Uploaded Image', use_column_width=True)
22
- st.write("")
23
- st.write("Detecting...")
 
 
 
 
 
 
 
 
 
 
24
 
 
 
25
  # Perform inference
26
- results = model.predict(image)
27
- r = results[0]
28
- im_bgr = r.plot(conf=False, pil=True, font_size=48, line_width=3) # Returns a PIL image if pil=True
29
- im_rgb = Image.fromarray(im_bgr[..., ::-1]) # Convert BGR to RGB
 
 
 
 
30
 
31
- # Display the prediction
32
- st.image(im_rgb, caption='Prediction', use_column_width=True)
 
 
 
 
 
 
 
 
33
 
34
- # Optionally, display pre-computed example images
35
- if st.checkbox('Show Example Results'):
36
- st.image('example1.jpg', use_column_width=True, caption='Example 1')
37
- st.image('example2.jpg', use_column_width=True, caption='Example 2')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from PIL import Image
3
  import numpy as np
4
  from ultralytics import YOLO # Make sure this import works in your Hugging Face environment
5
+ from io import BytesIO
6
+
7
 
 
8
  @st.cache_resource
9
  def load_model():
10
+ """
11
+ Load and cache the model
12
+ """
13
  model = YOLO("weights.pt") # Adjust path if needed
14
  return model
15
 
16
+ def predict(model, image, font_size, line_width):
17
+ """
18
+ Run inference and return annotated image
19
+ """
20
+ results = model.predict(image)
21
+ r = results[0]
22
+ im_bgr = r.plot(conf=False, pil=True, font_size=font_size, line_width=line_width) # Returns a PIL image if pil=True
23
+ im_rgb = Image.fromarray(im_bgr[..., ::-1]) # Convert BGR to RGB
24
+ return im_rgb
25
 
26
+ def file_uploader_cb(uploaded_file, font_size, line_width):
27
  image = Image.open(uploaded_file).convert("RGB")
28
+ col1, col2 = st.columns(2)
29
+ with col1:
30
+ # Display Uploaded image
31
+ st.image(image, caption='Uploaded Image', use_column_width=True)
32
+ # Perform inference
33
+ annotated_img = predict(model, image, font_size, line_width)
34
+ with col2:
35
+ # Display the prediction
36
+ st.image(annotated_img, caption='Prediction', use_column_width=True)
37
+ # write image to memory buffer for download
38
+ imbuffer = BytesIO()
39
+ annotated_img.save(imbuffer, format="JPEG")
40
+ st.download_button("Download Annotated Image", data=imbuffer, file_name="Annotated_Sketch.jpeg", mime="image/jpeg", key="upload")
41
 
42
+ def image_capture_cb(capture, font_size, line_width, col):
43
+ image = Image.open(capture).convert("RGB")
44
  # Perform inference
45
+ annotated_img = predict(model, image, font_size, line_width)
46
+ with col:
47
+ # Display the prediction
48
+ st.image(annotated_img, caption='Prediction', use_column_width=True)
49
+ # write image to memory buffer for download
50
+ imbuffer = BytesIO()
51
+ annotated_img.save(imbuffer, format="JPEG")
52
+ st.download_button("Download Annotated Image", data=imbuffer, file_name="Annotated_Sketch.jpeg", mime="image/jpeg", key="capture")
53
 
54
+ if __name__ == "__main__":
55
+ # set page configurations and display/annotation options
56
+ st.set_page_config(
57
+ page_title="Circuit Sketch Recognizer",
58
+ layout="wide"
59
+ )
60
+ st.title("Circuit Sketch Recognition")
61
+ with st.sidebar:
62
+ font_size = st.slider(label="Font Size", min_value=6, max_value=64, step=1, value=24)
63
+ line_width = st.slider(label="Bounding Box Line Thickness", min_value=1, max_value=8, step=1, value=3)
64
 
65
+ model = load_model()
66
+
67
+ # user specifies to take/upload picture, view examples
68
+ tabs = st.tabs(["Capture Picture", "Upload Your Image", "Show Examples"])
69
+ with tabs[0]:
70
+ # File uploader allows user to add their own image
71
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
72
+ if uploaded_file is not None:
73
+ file_uploader_cb(uploaded_file, font_size, line_width)
74
+ with tabs[1]:
75
+ # Camera Input allows user to take a picture
76
+ col1, col2 = st.columns(2)
77
+ with col1:
78
+ capture = st.camera_input("Take a picture with Camera")
79
+ if capture is not None:
80
+ image_capture_cb(capture, font_size, line_width, col2)
81
+ with tabs[2]:
82
+ col1, col2 = st.columns(2)
83
+ with col1:
84
+ st.image('example1.jpg', use_column_width=True, caption='Example 1')
85
+ with col2:
86
+ st.image('example2.jpg', use_column_width=True, caption='Example 2')