Martlgap committed
Commit cb74f9c
1 Parent(s): bffe7b3

Running version, only with a bit of delay

Files changed (4):
  1. app.py +73 -55
  2. tools/face_recognition.py +11 -6
  3. tools/nametypes.py +11 -11
  4. tools/utils.py +11 -8
app.py CHANGED

@@ -12,7 +12,7 @@ from streamlit_toggle import st_toggle_switch
 import pandas as pd
 from tools.nametypes import Stats, Detection
 from pathlib import Path
-from tools.utils import get_ice_servers, download_file, display_match, rgb
+from tools.utils import get_ice_servers, download_file, display_match, rgb, format_list
 from tools.face_recognition import (
     detect_faces,
     align_faces,
@@ -54,12 +54,19 @@ with st.sidebar:
         track_color=rgb(50, 50, 50),
     )

-    st.markdown("## Webcam")
+    st.markdown("## Webcam & Stream")
     resolution = st.selectbox(
         "Webcam Resolution",
         [(1920, 1080), (1280, 720), (640, 360)],
         index=2,
     )
+    st.markdown("Note: To change the resolution, you have to restart the stream.")
+
+    ice_server = st.selectbox("ICE Server", ["twilio", "metered"], index=0)
+    st.markdown(
+        "Note: metered is a free server with limited bandwidth and can take a while to connect. Twilio is a paid service that I pay for, so please don't abuse it."
+    )
+
     st.markdown("## Face Detection")
     max_faces = st.number_input("Maximum Number of Faces", value=2, min_value=1)
     detection_confidence = st.slider(
@@ -68,17 +75,13 @@ with st.sidebar:
     tracking_confidence = st.slider(
         "Min Tracking Confidence", min_value=0.0, max_value=1.0, value=0.9
     )
-    on_draw = st_toggle_switch(
-        "Show Drawings",
-        key="show_drawings",
-        default_value=True,
-        active_color=rgb(255, 75, 75),
-        track_color=rgb(100, 100, 100),
-    )
     st.markdown("## Face Recognition")
     similarity_threshold = st.slider(
         "Similarity Threshold", min_value=0.0, max_value=2.0, value=0.67
     )
+    st.markdown(
+        "This sets a maximum cosine distance between the embedding of a detected face and the embeddings of the gallery images. If the distance is below the threshold, the face is recognized as the gallery identity with the lowest distance; if it is above the threshold, the face is not recognized."
+    )

 download_file(
     MODEL_URL,
@@ -94,6 +97,16 @@ else:
     face_recognition_model = tflite.Interpreter(model_path=MODEL_LOCAL_PATH.as_posix())
     st.session_state[cache_key] = face_recognition_model

+# Session-specific caching of the face recognition model used for the gallery
+cache_key = "face_id_model_gal"
+if cache_key in st.session_state:
+    face_recognition_model_gal = st.session_state[cache_key]
+else:
+    face_recognition_model_gal = tflite.Interpreter(
+        model_path=MODEL_LOCAL_PATH.as_posix()
+    )
+    st.session_state[cache_key] = face_recognition_model_gal
+
 # Session-specific caching of the face detection model
 cache_key = "face_detection_model"
 if cache_key in st.session_state:
@@ -112,58 +125,58 @@ detections_queue: "queue.Queue[List[Detection]]" = queue.Queue()


 def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-    detections = None
+    # Initialize detections
+    detections = []
+
+    # Initialize stats
+    stats = Stats()
+
+    # Start timer for FPS calculation
     frame_start = time.time()

     # Convert frame to numpy array
     frame = frame.to_ndarray(format="rgb24")

-    # Get frame resolution
+    # Get frame resolution and add to stats
     resolution = frame.shape
+    stats = stats._replace(resolution=resolution)

-    start = time.time()
     if face_rec_on:
+        # Run face detection
+        start = time.time()
         detections = detect_faces(frame, face_detection_model)
-    time_detection = (time.time() - start) * 1000
+        stats = stats._replace(num_faces=len(detections) if detections else 0)
+        stats = stats._replace(detection=(time.time() - start) * 1000)

-    start = time.time()
-    if face_rec_on:
+        # Run face alignment
+        start = time.time()
         detections = align_faces(frame, detections)
-    time_normalization = (time.time() - start) * 1000
+        stats = stats._replace(alignment=(time.time() - start) * 1000)

-    start = time.time()
-    if face_rec_on:
+        # Run inference
+        start = time.time()
         detections = inference(detections, face_recognition_model)
-    time_inference = (time.time() - start) * 1000
+        stats = stats._replace(inference=(time.time() - start) * 1000)

-    start = time.time()
-    if face_rec_on:
+        # Run face recognition
+        start = time.time()
         detections = recognize_faces(detections, gallery, similarity_threshold)
-    time_recognition = (time.time() - start) * 1000
+        stats = stats._replace(recognition=(time.time() - start) * 1000)

-    start = time.time()
-    if face_rec_on and on_draw:
+        # Draw detections
+        start = time.time()
         frame = draw_detections(frame, detections)
-    time_drawing = (time.time() - start) * 1000
+        stats = stats._replace(drawing=(time.time() - start) * 1000)

     # Convert frame back to av.VideoFrame
     frame = av.VideoFrame.from_ndarray(frame, format="rgb24")

-    # Put detections, stats and timings into queues (to be accessible by other thread)
-    if face_rec_on:
-        detections_queue.put(detections)
-        stats_queue.put(
-            Stats(
-                fps=1 / (time.time() - frame_start),
-                resolution=resolution,
-                num_faces=len(detections) if detections else 0,
-                detection=time_detection,
-                normalization=time_normalization,
-                inference=time_inference,
-                recognition=time_recognition,
-                drawing=time_drawing,
-            )
-        )
+    # Calculate FPS and add to stats
+    stats = stats._replace(fps=1 / (time.time() - frame_start))
+
+    # Send data to other thread
+    detections_queue.put(detections)
+    stats_queue.put(stats)

     return frame

@@ -176,7 +189,7 @@ gallery = st.sidebar.file_uploader(
     "Upload images to gallery", type=["png", "jpg", "jpeg"], accept_multiple_files=True
 )
 if gallery:
-    gallery = process_gallery(gallery, face_detection_model, face_recognition_model)
+    gallery = process_gallery(gallery, face_detection_model, face_recognition_model_gal)
     st.sidebar.markdown("**Gallery Images**")
     st.sidebar.image(
         [identity.image for identity in gallery],
@@ -190,7 +203,7 @@ stats = st.empty()
 ctx = webrtc_streamer(
     key="FaceIDAppDemo",
     mode=WebRtcMode.SENDRECV,
-    rtc_configuration={"iceServers": get_ice_servers("twilio")},
+    rtc_configuration={"iceServers": get_ice_servers(name=ice_server)},
     video_frame_callback=video_frame_callback,
     media_stream_constraints={
         "video": {
@@ -198,16 +211,18 @@ ctx = webrtc_streamer(
                 "min": resolution[0],
                 "ideal": resolution[0],
                 "max": resolution[0],
-            }
+            },
+            "height": {
+                "min": resolution[1],
+                "ideal": resolution[1],
+                "max": resolution[1],
+            },
         },
         "audio": False,
     },
-    async_processing=False,  # WHAT IS THIS?
+    async_processing=True,
 )

-st.markdown("**Timings [ms]**")
-timings = st.empty()
-
 st.markdown("**Identified Faces**")
 identified_faces = st.empty()

@@ -217,19 +232,24 @@ detections = st.empty()
 # Display Live Stats
 if ctx.state.playing:
     while True:
+        # Get stats
         stats_dataframe = pd.DataFrame([stats_queue.get()])
+
+        # Write stats to streamlit
         stats.dataframe(stats_dataframe.style.format(thousands=" ", precision=2))

+        # Get detections
         detections_data = detections_queue.get()
-        detections_dataframe = pd.DataFrame(detections_data).drop(
-            columns=["face", "face_match"], errors="ignore"
-        )
-        # Apply formatting to DataFrame
-        # print(detections_dataframe.columns)
-        # detections_dataframe["embedding"] = detections_dataframe["embedding"].embedding.applymap(format_floats)
+        detections_dataframe = (
+            pd.DataFrame(detections_data)
+            .drop(columns=["face", "face_match"], errors="ignore")
+            .applymap(lambda x: format_list(x))
+        )

+        # Write detections to streamlit
         detections.dataframe(detections_dataframe)

+        # Write identified faces to streamlit
         identified_faces.image(
             [display_match(d) for d in detections_data if d.name is not None],
             caption=[
@@ -238,6 +258,4 @@ if ctx.state.playing:
                 if d.name is not None
             ],
             width=112,
-        )  # TODO formatting
-
-        # time.sleep(1)
+        )
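A note on the handoff pattern above: video_frame_callback runs on streamlit-webrtc's worker thread, while the "while True" display loop runs on the Streamlit script thread, so every frame's Stats and detections travel through thread-safe queue.Queue objects. Below is a minimal, self-contained sketch of that producer/consumer pattern; the class and function names here are illustrative stand-ins, not the app's own.

import queue
import threading
import time
from typing import NamedTuple


class Stats(NamedTuple):
    # Stripped-down stand-in for tools.nametypes.Stats
    fps: float = 0


stats_queue: "queue.Queue[Stats]" = queue.Queue()


def fake_video_callback():
    # Stands in for video_frame_callback on the WebRTC worker thread
    for _ in range(3):
        frame_start = time.time()
        time.sleep(0.01)  # pretend to process one frame
        stats_queue.put(Stats(fps=1 / (time.time() - frame_start)))


threading.Thread(target=fake_video_callback).start()

# Stands in for the display loop: queue.get() blocks until the worker
# thread has produced the next item, which also paces the loop.
for _ in range(3):
    print(stats_queue.get())

Because get() blocks until the next frame arrives, the loop is naturally rate-limited, which is presumably why the commented-out time.sleep(1) could be dropped.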
 
 
tools/face_recognition.py CHANGED

@@ -62,6 +62,7 @@ def align_faces(img, detections):
     )
     return updated_detections

+
 # TODO Error when uploading image while running!
 def inference(detections, model):
     updated_detections = []
@@ -78,17 +79,16 @@ def inference(detections, model):
     ]

     for idx, detection in enumerate(detections):
-        updated_detections.append(detection._replace(emdedding=embs[idx]))
+        updated_detections.append(detection._replace(embedding=embs[idx]))
     return updated_detections


 def recognize_faces(detections, gallery, thresh=0.67):
-
     if len(gallery) == 0 or len(detections) == 0:
         return detections

     gallery_embs = np.asarray([identity.embedding for identity in gallery])
-    detection_embs = np.asarray([detection.emdedding for detection in detections])
+    detection_embs = np.asarray([detection.embedding for detection in detections])

     cos_distances = cosine_distances(detection_embs, gallery_embs)

@@ -103,8 +103,13 @@ def recognize_faces(detections, gallery, thresh=0.67):
             pred = idx_min
         updated_detections.append(
             detection._replace(
-                name=gallery[pred].name.split(".jpg")[0].split(".png")[0].split(".jpeg")[0] if pred is not None else None,
-                emdedding_match=gallery[pred].embedding if pred is not None else None,
+                name=gallery[pred]
+                .name.split(".jpg")[0]
+                .split(".png")[0]
+                .split(".jpeg")[0]
+                if pred is not None
+                else None,
+                embedding_match=gallery[pred].embedding if pred is not None else None,
                 face_match=gallery[pred].image if pred is not None else None,
                 distance=dist,
             )
@@ -135,7 +140,7 @@ def process_gallery(files, face_detection_model, face_recognition_model):
         gallery.append(
             Identity(
                 name=file.name,
-                embedding=detections[0].emdedding,
+                embedding=detections[0].embedding,
                 image=detections[0].face,
             )
         )
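For intuition on recognize_faces: cosine distance is 1 minus cosine similarity, so it ranges from 0 (same direction) to 2 (opposite directions), which is why the app's Similarity Threshold slider runs from 0.0 to 2.0. Here is a self-contained sketch of the matching logic with toy embeddings; the cosine_distances helper below mimics the sklearn-style function the module calls (its actual import is not visible in this diff).

import numpy as np


def cosine_distances(a, b):
    # 1 - cosine similarity for every (row of a, row of b) pair
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 1.0 - a @ b.T


gallery_embs = np.asarray([[1.0, 0.0], [0.0, 1.0]])      # two gallery identities
detection_embs = np.asarray([[0.9, 0.1], [-1.0, 0.0]])   # two detected faces

thresh = 0.67
for dists in cosine_distances(detection_embs, gallery_embs):
    idx_min = int(np.argmin(dists))
    pred = idx_min if dists[idx_min] < thresh else None
    print(pred, round(float(dists[idx_min]), 3))
# -> 0 0.006   (first face matches gallery identity 0)
# -> None 1.0  (second face is too far from every gallery image)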
tools/nametypes.py CHANGED

@@ -7,22 +7,22 @@ class Detection(NamedTuple):
     landmarks: List[List[int]]
     name: str = None
     face: np.ndarray = None
-    emdedding: np.ndarray = None
-    emdedding_match: np.ndarray = None
+    embedding: np.ndarray = None
+    embedding_match: np.ndarray = None
     face_match: np.ndarray = None
     distance: float = None


 class Stats(NamedTuple):
-    fps: float
-    resolution: List[int]
-    num_faces: int
-    detection: float
-    normalization: float
-    inference: float
-    recognition: float
-    drawing: float
+    fps: float = 0
+    resolution: List[int] = [None, None, None]
+    num_faces: int = 0
+    detection: float = None
+    alignment: float = None
+    inference: float = None
+    recognition: float = None
+    drawing: float = None


 class Identity(NamedTuple):
     name: str
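The point of giving every Stats field a default is that video_frame_callback can now start from an empty Stats() and fill in fields one pipeline stage at a time. NamedTuple instances are immutable, so _replace returns a new tuple rather than mutating in place. A quick illustration with a trimmed-down copy of the class:

from typing import NamedTuple


class Stats(NamedTuple):
    fps: float = 0
    num_faces: int = 0
    detection: float = None  # ms; stays None while face recognition is off


s = Stats()                     # valid even before any pipeline stage has run
s = s._replace(detection=12.3)  # returns a NEW tuple; the old one is unchanged
print(s)                        # Stats(fps=0, num_faces=0, detection=12.3)

One caveat worth knowing: the resolution default is a shared mutable list, which is harmless here only because the app always replaces it via _replace rather than mutating it in place.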
tools/utils.py CHANGED

@@ -110,7 +110,6 @@ def download_file(url, model_path: Path, file_hash=None):
         download = True

     if download:
-
         # These are handles to two visual elements to animate.
         weights_warning, progress_bar = None, None
         try:
@@ -144,14 +143,18 @@


 # Function to format floats within a list
-def format_floats(val):
-    if isinstance(val, list):
-        return [f"{num:.2f}" for num in val]
+def format_list(val):
+    if isinstance(val, list):
+        return [format_list(num) for num in val]
     if isinstance(val, np.ndarray):
-        return np.asarray([f"{num:.2f}" for num in val])
+        return np.asarray([format_list(num) for num in val])
+    if isinstance(val, np.float32):
+        return f"{val:.2f}"
+    if isinstance(val, float):
+        return f"{val:.2f}"
     else:
         return val


 def display_match(d):
     im = np.concatenate([d.face, d.face_match])
@@ -163,10 +166,10 @@ def display_match(d):
         left=border_size,
         right=border_size,
         borderType=cv2.BORDER_CONSTANT,
-        value=(255, 255, 120)
+        value=(255, 255, 120),
     )
     return border


 def rgb(r, g, b):
-    return '#{:02x}{:02x}{:02x}'.format(r, g, b)
+    return "#{:02x}{:02x}{:02x}".format(r, g, b)