Martlgap commited on
Commit
ec0a6ec
1 Parent(s): 678e0d3

fixed upload issue

Browse files
Files changed (2) hide show
  1. app.py +21 -19
  2. tools/utils.py +3 -3
app.py CHANGED
@@ -12,7 +12,7 @@ from streamlit_toggle import st_toggle_switch
12
  import pandas as pd
13
  from tools.nametypes import Stats, Detection
14
  from pathlib import Path
15
- from tools.utils import get_ice_servers, download_file, display_match, rgb, format_list
16
  from tools.face_recognition import (
17
  detect_faces,
18
  align_faces,
@@ -22,9 +22,6 @@ from tools.face_recognition import (
22
  process_gallery,
23
  )
24
 
25
- # TODO Error Handling!
26
-
27
-
28
  # Set logging level to error (To avoid getting spammed by queue warnings etc.)
29
  logger = logging.getLogger(__name__)
30
  logging.basicConfig(level=logging.ERROR)
@@ -54,14 +51,6 @@ with st.sidebar:
54
  track_color=rgb(50, 50, 50),
55
  )
56
 
57
- upscale = st_toggle_switch(
58
- "Upscale",
59
- key="upscale",
60
- default_value=True,
61
- active_color=rgb(255, 75, 75),
62
- track_color=rgb(50, 50, 50),
63
- )
64
-
65
  st.markdown("## Webcam & Stream")
66
  resolution = st.selectbox(
67
  "Webcam Resolution",
@@ -128,6 +117,19 @@ else:
128
  )
129
  st.session_state[cache_key] = face_detection_model
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  stats_queue: "queue.Queue[Stats]" = queue.Queue()
132
  detections_queue: "queue.Queue[List[Detection]]" = queue.Queue()
133
 
@@ -173,7 +175,7 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
173
 
174
  # Draw detections
175
  start = time.time()
176
- frame = draw_detections(frame, detections, upscale=upscale)
177
  stats = stats._replace(drawing=(time.time() - start) * 1000)
178
 
179
  # Convert frame back to av.VideoFrame
@@ -183,8 +185,8 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
183
  stats = stats._replace(fps=1 / (time.time() - frame_start))
184
 
185
  # Send data to other thread
186
- detections_queue.put(detections)
187
- stats_queue.put(stats)
188
 
189
  return frame
190
 
@@ -197,7 +199,7 @@ gallery = st.sidebar.file_uploader(
197
  "Upload images to gallery", type=["png", "jpg", "jpeg"], accept_multiple_files=True
198
  )
199
  if gallery:
200
- gallery = process_gallery(gallery, face_detection_model, face_recognition_model_gal)
201
  st.sidebar.markdown("**Gallery Images**")
202
  st.sidebar.image(
203
  [identity.image for identity in gallery],
@@ -241,17 +243,17 @@ detections = st.empty()
241
  if ctx.state.playing:
242
  while True:
243
  # Get stats
244
- stats_dataframe = pd.DataFrame([stats_queue.get()])
245
 
246
  # Write stats to streamlit
247
  stats.dataframe(stats_dataframe.style.format(thousands=" ", precision=2))
248
 
249
  # Get detections
250
- detections_data = detections_queue.get()
251
  detections_dataframe = (
252
  pd.DataFrame(detections_data)
253
  .drop(columns=["face", "face_match"], errors="ignore")
254
- .applymap(lambda x: (format_list(x)))
255
  )
256
 
257
  # Write detections to streamlit
 
12
  import pandas as pd
13
  from tools.nametypes import Stats, Detection
14
  from pathlib import Path
15
+ from tools.utils import get_ice_servers, download_file, display_match, rgb, format_dflist
16
  from tools.face_recognition import (
17
  detect_faces,
18
  align_faces,
 
22
  process_gallery,
23
  )
24
 
 
 
 
25
  # Set logging level to error (To avoid getting spammed by queue warnings etc.)
26
  logger = logging.getLogger(__name__)
27
  logging.basicConfig(level=logging.ERROR)
 
51
  track_color=rgb(50, 50, 50),
52
  )
53
 
 
 
 
 
 
 
 
 
54
  st.markdown("## Webcam & Stream")
55
  resolution = st.selectbox(
56
  "Webcam Resolution",
 
117
  )
118
  st.session_state[cache_key] = face_detection_model
119
 
120
+ # Session-specific caching of the face detection model
121
+ cache_key = "face_detection_model_gal"
122
+ if cache_key in st.session_state:
123
+ face_detection_model_gal = st.session_state[cache_key]
124
+ else:
125
+ face_detection_model_gal = mp.solutions.face_mesh.FaceMesh(
126
+ refine_landmarks=True,
127
+ min_detection_confidence=detection_confidence,
128
+ min_tracking_confidence=tracking_confidence,
129
+ max_num_faces=max_faces,
130
+ )
131
+ st.session_state[cache_key] = face_detection_model_gal
132
+
133
  stats_queue: "queue.Queue[Stats]" = queue.Queue()
134
  detections_queue: "queue.Queue[List[Detection]]" = queue.Queue()
135
 
 
175
 
176
  # Draw detections
177
  start = time.time()
178
+ frame = draw_detections(frame, detections)
179
  stats = stats._replace(drawing=(time.time() - start) * 1000)
180
 
181
  # Convert frame back to av.VideoFrame
 
185
  stats = stats._replace(fps=1 / (time.time() - frame_start))
186
 
187
  # Send data to other thread
188
+ detections_queue.put_nowait(detections)
189
+ stats_queue.put_nowait(stats)
190
 
191
  return frame
192
 
 
199
  "Upload images to gallery", type=["png", "jpg", "jpeg"], accept_multiple_files=True
200
  )
201
  if gallery:
202
+ gallery = process_gallery(gallery, face_detection_model_gal, face_recognition_model_gal)
203
  st.sidebar.markdown("**Gallery Images**")
204
  st.sidebar.image(
205
  [identity.image for identity in gallery],
 
243
  if ctx.state.playing:
244
  while True:
245
  # Get stats
246
+ stats_dataframe = pd.DataFrame([stats_queue.get(timeout=10)])
247
 
248
  # Write stats to streamlit
249
  stats.dataframe(stats_dataframe.style.format(thousands=" ", precision=2))
250
 
251
  # Get detections
252
+ detections_data = detections_queue.get(timeout=10)
253
  detections_dataframe = (
254
  pd.DataFrame(detections_data)
255
  .drop(columns=["face", "face_match"], errors="ignore")
256
+ .applymap(lambda x: (format_dflist(x)))
257
  )
258
 
259
  # Write detections to streamlit
tools/utils.py CHANGED
@@ -143,11 +143,11 @@ def download_file(url, model_path: Path, file_hash=None):
143
 
144
 
145
  # Function to format floats within a list
146
- def format_list(val):
147
  if isinstance(val, list):
148
- return [format_list(num) for num in val]
149
  if isinstance(val, np.ndarray):
150
- return np.asarray([format_list(num) for num in val])
151
  if isinstance(val, np.float32):
152
  return f"{val:.2f}"
153
  if isinstance(val, float):
 
143
 
144
 
145
  # Function to format floats within a list
146
+ def format_dflist(val):
147
  if isinstance(val, list):
148
+ return [format_dflist(num) for num in val]
149
  if isinstance(val, np.ndarray):
150
+ return np.asarray([format_dflist(num) for num in val])
151
  if isinstance(val, np.float32):
152
  return f"{val:.2f}"
153
  if isinstance(val, float):