Commit: running version, only with a bit of delay
Files changed:
- app.py (+73 −55)
- tools/face_recognition.py (+11 −6)
- tools/nametypes.py (+11 −11)
- tools/utils.py (+11 −8)
app.py CHANGED
@@ -12,7 +12,7 @@ from streamlit_toggle import st_toggle_switch
 import pandas as pd
 from tools.nametypes import Stats, Detection
 from pathlib import Path
-from tools.utils import get_ice_servers, download_file, display_match, rgb
+from tools.utils import get_ice_servers, download_file, display_match, rgb, format_list
 from tools.face_recognition import (
     detect_faces,
     align_faces,
@@ -54,12 +54,19 @@ with st.sidebar:
         track_color=rgb(50, 50, 50),
     )

-    st.markdown("## Webcam")
+    st.markdown("## Webcam & Stream")
     resolution = st.selectbox(
         "Webcam Resolution",
         [(1920, 1080), (1280, 720), (640, 360)],
         index=2,
     )
+    st.markdown("Note: To change the resolution, you have to restart the stream.")
+
+    ice_server = st.selectbox("ICE Server", ["twilio", "metered"], index=0)
+    st.markdown(
+        "Note: metered is a free server with limited bandwidth and can take a while to connect. Twilio is a paid service and is paid for by me, so please don't abuse it."
+    )
+
     st.markdown("## Face Detection")
     max_faces = st.number_input("Maximum Number of Faces", value=2, min_value=1)
     detection_confidence = st.slider(
@@ -68,17 +75,13 @@ with st.sidebar:
     tracking_confidence = st.slider(
         "Min Tracking Confidence", min_value=0.0, max_value=1.0, value=0.9
     )
-    on_draw = st_toggle_switch(
-        "Show Drawings",
-        key="show_drawings",
-        default_value=True,
-        active_color=rgb(255, 75, 75),
-        track_color=rgb(100, 100, 100),
-    )
     st.markdown("## Face Recognition")
     similarity_threshold = st.slider(
         "Similarity Threshold", min_value=0.0, max_value=2.0, value=0.67
     )
+    st.markdown(
+        "This sets a maximum distance for the cosine similarity between the embeddings of the detected face and the gallery images. If the distance is below the threshold, the face is recognized as the gallery image with the lowest distance. If the distance is above the threshold, the face is not recognized."
+    )

 download_file(
     MODEL_URL,
@@ -94,6 +97,16 @@ else:
     face_recognition_model = tflite.Interpreter(model_path=MODEL_LOCAL_PATH.as_posix())
     st.session_state[cache_key] = face_recognition_model

+# Session-specific caching of the face recognition model
+cache_key = "face_id_model_gal"
+if cache_key in st.session_state:
+    face_recognition_model_gal = st.session_state[cache_key]
+else:
+    face_recognition_model_gal = tflite.Interpreter(
+        model_path=MODEL_LOCAL_PATH.as_posix()
+    )
+    st.session_state[cache_key] = face_recognition_model_gal
+
 # Session-specific caching of the face detection model
 cache_key = "face_detection_model"
 if cache_key in st.session_state:
@@ -112,58 +125,58 @@ detections_queue: "queue.Queue[List[Detection]]" = queue.Queue()


 def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-
+    # Initialize detections
+    detections = []
+
+    # Initialize stats
+    stats = Stats()
+
+    # Start timer for FPS calculation
     frame_start = time.time()

     # Convert frame to numpy array
     frame = frame.to_ndarray(format="rgb24")

-    # Get frame resolution
+    # Get frame resolution and add to stats
     resolution = frame.shape
+    stats = stats._replace(resolution=resolution)

-    start = time.time()
     if face_rec_on:
+        # Run face detection
+        start = time.time()
         detections = detect_faces(frame, face_detection_model)
-        time_detection = (time.time() - start) * 1000
+        stats = stats._replace(num_faces=len(detections) if detections else 0)
+        stats = stats._replace(detection=(time.time() - start) * 1000)

-        start = time.time()
+        # Run face alignment
+        start = time.time()
         detections = align_faces(frame, detections)
-        time_normalization = (time.time() - start) * 1000
+        stats = stats._replace(alignment=(time.time() - start) * 1000)

-        start = time.time()
+        # Run inference
+        start = time.time()
         detections = inference(detections, face_recognition_model)
-        time_inference = (time.time() - start) * 1000
+        stats = stats._replace(inference=(time.time() - start) * 1000)

-        start = time.time()
+        # Run face recognition
+        start = time.time()
         detections = recognize_faces(detections, gallery, similarity_threshold)
-        time_recognition = (time.time() - start) * 1000
+        stats = stats._replace(recognition=(time.time() - start) * 1000)

-        start = time.time()
+        # Draw detections
+        start = time.time()
         frame = draw_detections(frame, detections)
-        time_drawing = (time.time() - start) * 1000
+        stats = stats._replace(drawing=(time.time() - start) * 1000)

     # Convert frame back to av.VideoFrame
     frame = av.VideoFrame.from_ndarray(frame, format="rgb24")

-    # Put stats and detections into queues
-    detections_queue.put(detections)
-
-    stats_queue.put(
-        Stats(
-            fps=1 / (time.time() - frame_start),
-            resolution=resolution,
-            num_faces=len(detections) if detections else 0,
-            detection=time_detection,
-            normalization=time_normalization,
-            inference=time_inference,
-            recognition=time_recognition,
-            drawing=time_drawing,
-        )
-    )
+    # Calculate FPS and add to stats
+    stats = stats._replace(fps=1 / (time.time() - frame_start))
+
+    # Send data to other thread
+    detections_queue.put(detections)
+    stats_queue.put(stats)

     return frame

@@ -176,7 +189,7 @@ gallery = st.sidebar.file_uploader(
     "Upload images to gallery", type=["png", "jpg", "jpeg"], accept_multiple_files=True
 )
 if gallery:
-    gallery = process_gallery(gallery, face_detection_model, face_recognition_model)
+    gallery = process_gallery(gallery, face_detection_model, face_recognition_model_gal)
     st.sidebar.markdown("**Gallery Images**")
     st.sidebar.image(
         [identity.image for identity in gallery],
@@ -190,7 +203,7 @@ stats = st.empty()
 ctx = webrtc_streamer(
     key="FaceIDAppDemo",
     mode=WebRtcMode.SENDRECV,
-    rtc_configuration={"iceServers": get_ice_servers()},
+    rtc_configuration={"iceServers": get_ice_servers(name=ice_server)},
     video_frame_callback=video_frame_callback,
     media_stream_constraints={
         "video": {
@@ -198,16 +211,18 @@ ctx = webrtc_streamer(
                 "min": resolution[0],
                 "ideal": resolution[0],
                 "max": resolution[0],
-            }
+            },
+            "height": {
+                "min": resolution[1],
+                "ideal": resolution[1],
+                "max": resolution[1],
+            },
         },
         "audio": False,
     },
-    async_processing=False,
+    async_processing=True,
 )

-st.markdown("**Timings [ms]**")
-timings = st.empty()
-
 st.markdown("**Identified Faces**")
 identified_faces = st.empty()

@@ -217,19 +232,24 @@ detections = st.empty()
 # Display Live Stats
 if ctx.state.playing:
     while True:
+        # Get stats
         stats_dataframe = pd.DataFrame([stats_queue.get()])
+
+        # Write stats to streamlit
         stats.dataframe(stats_dataframe.style.format(thousands=" ", precision=2))

+        # Get detections
         detections_data = detections_queue.get()
-        detections_dataframe = pd.DataFrame(detections_data).drop(
-            columns=["face", "face_match"], errors="ignore"
+        detections_dataframe = (
+            pd.DataFrame(detections_data)
+            .drop(columns=["face", "face_match"], errors="ignore")
+            .applymap(lambda x: (format_list(x)))
         )
-        # Apply formatting to DataFrame
-        # print(detections_dataframe.columns)
-        # detections_dataframe["embedding"] = detections_dataframe["embedding"].embedding.applymap(format_floats)

+        # Write detections to streamlit
         detections.dataframe(detections_dataframe)

+        # Write identified faces to streamlit
         identified_faces.image(
             [display_match(d) for d in detections_data if d.name is not None],
             caption=[
@@ -238,6 +258,4 @@ if ctx.state.playing:
                 if d.name is not None
             ],
             width=112,
-        )
-
-        # time.sleep(1)
+        )
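Note on the model caching above: Streamlit reruns the whole script on every widget interaction, so the tflite.Interpreter instances are stashed in st.session_state. A minimal sketch of the pattern, assuming tflite_runtime is the `tflite` alias app.py imports (the helper name get_cached_interpreter is ours, not the app's):

import streamlit as st
import tflite_runtime.interpreter as tflite  # assumption: app.py imports tflite this way

def get_cached_interpreter(cache_key: str, model_path: str):
    # Build the interpreter once per browser session, reuse it on every rerun
    if cache_key not in st.session_state:
        st.session_state[cache_key] = tflite.Interpreter(model_path=model_path)
    return st.session_state[cache_key]

face_recognition_model_gal = get_cached_interpreter("face_id_model_gal", "model.tflite")

The commit adds a second, separately cached interpreter for the gallery, presumably so process_gallery does not invoke the same interpreter the streaming callback is using concurrently (compare the "# TODO Error when uploading image while running!" note in tools/face_recognition.py).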
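The rewritten video_frame_callback threads one immutable Stats tuple through the pipeline instead of loose time_* variables. NamedTuple instances are immutable, so _replace returns a new tuple rather than mutating in place. A self-contained sketch of the timing pattern, with StepStats as a hypothetical stand-in for tools.nametypes.Stats:

import time
from typing import NamedTuple

class StepStats(NamedTuple):  # stand-in for tools.nametypes.Stats
    detection: float = None  # milliseconds
    fps: float = 0

stats = StepStats()
frame_start = time.time()

start = time.time()
_ = sum(i * i for i in range(100_000))  # placeholder for detect_faces(...)
stats = stats._replace(detection=(time.time() - start) * 1000)

stats = stats._replace(fps=1 / (time.time() - frame_start))
print(stats)

Because every field now has a default, the callback can start from an empty Stats() and fill in only the steps that actually ran, which is why the commit also adds defaults in tools/nametypes.py.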
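The queues are the bridge between threads: video_frame_callback runs on streamlit-webrtc's worker thread, while the `while True` display loop at the bottom of app.py runs in the script thread, so results are handed over through thread-safe queue.Queue objects. A minimal sketch of the same producer/consumer shape (fake_callback is a hypothetical stand-in for the WebRTC callback):

import queue
import threading
import time

stats_queue: "queue.Queue[dict]" = queue.Queue()

def fake_callback() -> None:
    # Stands in for video_frame_callback publishing one stats record per frame
    for i in range(3):
        stats_queue.put({"frame": i, "fps": 30.0})
        time.sleep(0.01)

threading.Thread(target=fake_callback, daemon=True).start()

for _ in range(3):
    print(stats_queue.get())  # blocks until the producer publishes, like the display loop

The blocking get() is what paces the display loop to the stream: it waits for the next frame's stats instead of polling.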
tools/face_recognition.py CHANGED
@@ -62,6 +62,7 @@ def align_faces(img, detections):
     )
     return updated_detections

+
 # TODO Error when uploading image while running!
 def inference(detections, model):
     updated_detections = []
@@ -78,17 +79,16 @@ def inference(detections, model):
     ]

     for idx, detection in enumerate(detections):
-        updated_detections.append(detection._replace(emb=embs[idx]))
+        updated_detections.append(detection._replace(embedding=embs[idx]))
     return updated_detections


 def recognize_faces(detections, gallery, thresh=0.67):
-
     if len(gallery) == 0 or len(detections) == 0:
         return detections

     gallery_embs = np.asarray([identity.embedding for identity in gallery])
-    detection_embs = np.asarray([detection.emb for detection in detections])
+    detection_embs = np.asarray([detection.embedding for detection in detections])

     cos_distances = cosine_distances(detection_embs, gallery_embs)

@@ -103,8 +103,13 @@ def recognize_faces(detections, gallery, thresh=0.67):
             pred = idx_min
         updated_detections.append(
             detection._replace(
-                name=gallery[pred]
-                .name if pred is not None else None,
+                name=gallery[pred]
+                .name.split(".jpg")[0]
+                .split(".png")[0]
+                .split(".jpeg")[0]
+                if pred is not None
+                else None,
+                embedding_match=gallery[pred].embedding if pred is not None else None,
                 face_match=gallery[pred].image if pred is not None else None,
                 distance=dist,
             )
@@ -135,7 +140,7 @@ def process_gallery(files, face_detection_model, face_recognition_model):
         gallery.append(
             Identity(
                 name=file.name,
-                embedding=detections[0].emb,
+                embedding=detections[0].embedding,
                 image=detections[0].face,
             )
         )
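Note on recognize_faces: each detection is matched to the gallery identity with the smallest cosine distance, and the match is kept only if that distance is under thresh. A toy sketch of the decision rule with random stand-in embeddings (the 128-D shape is an assumption; the 0.67 default mirrors the code above):

import numpy as np
from sklearn.metrics.pairwise import cosine_distances

rng = np.random.default_rng(0)
detection_embs = rng.normal(size=(2, 128))  # 2 detected faces, hypothetical embeddings
gallery_embs = rng.normal(size=(3, 128))    # 3 gallery identities

cos_distances = cosine_distances(detection_embs, gallery_embs)  # shape (2, 3)

thresh = 0.67
for dists in cos_distances:
    idx_min = int(np.argmin(dists))
    dist = float(dists[idx_min])
    pred = idx_min if dist < thresh else None  # None means "not recognized"
    print(pred, round(dist, 3))

With random vectors the distances hover around 1.0, so pred is usually None; embeddings of the same person are expected to land well below the threshold.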
tools/nametypes.py CHANGED
@@ -7,22 +7,22 @@ class Detection(NamedTuple):
     landmarks: List[List[int]]
     name: str = None
     face: np.ndarray = None
-    emb: np.ndarray = None
-    emb_match: np.ndarray = None
+    embedding: np.ndarray = None
+    embedding_match: np.ndarray = None
     face_match: np.ndarray = None
     distance: float = None


 class Stats(NamedTuple):
-    fps: float
-    resolution: List[int]
-    num_faces: int
-    detection: float
-    normalization: float
-    inference: float
-    recognition: float
-    drawing: float
+    fps: float = 0
+    resolution: List[int] = [None, None, None]
+    num_faces: int = 0
+    detection: float = None
+    alignment: float = None
+    inference: float = None
+    recognition: float = None
+    drawing: float = None


 class Identity(NamedTuple):
     name: str
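The added defaults are what allow app.py to construct an empty Stats() up front and fill fields step by step via _replace. A quick sketch of that behavior (Stats here mirrors the new definition above, trimmed to three fields):

from typing import List, NamedTuple

class Stats(NamedTuple):
    fps: float = 0
    resolution: List[int] = [None, None, None]
    num_faces: int = 0

s = Stats()                     # now constructible with no arguments
print(s._replace(num_faces=2))  # returns a new tuple; s itself is unchanged

One caveat worth noting: [None, None, None] is a single list object shared by every Stats instance that uses the default. That is harmless here because the app only ever replaces the field, never mutates the list, but a tuple default would rule out accidental aliasing.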
tools/utils.py CHANGED
@@ -110,7 +110,6 @@ def download_file(url, model_path: Path, file_hash=None):
         download = True

     if download:
-
         # These are handles to two visual elements to animate.
         weights_warning, progress_bar = None, None
         try:
@@ -144,14 +143,18 @@ def download_file(url, model_path: Path, file_hash=None):


 # Function to format floats within a list
-def format_floats(val):
-    if isinstance(val, list):
-        return [format_floats(num) for num in val]
+def format_list(val):
+    if isinstance(val, list):
+        return [format_list(num) for num in val]
     if isinstance(val, np.ndarray):
-        return np.asarray([format_floats(num) for num in val])
+        return np.asarray([format_list(num) for num in val])
+    if isinstance(val, np.float32):
+        return f"{val:.2f}"
+    if isinstance(val, float):
+        return f"{val:.2f}"
     else:
         return val


 def display_match(d):
     im = np.concatenate([d.face, d.face_match])
@@ -163,10 +166,10 @@ def display_match(d):
         left=border_size,
         right=border_size,
         borderType=cv2.BORDER_CONSTANT,
-        value=(255, 255, 120)
+        value=(255, 255, 120),
    )
     return border


 def rgb(r, g, b):
-    return f"rgb({r}, {g}, {b})"
+    return "#{:02x}{:02x}{:02x}".format(r, g, b)
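For reference, the completed helpers behave like this; the snippet mirrors the new format_list and rgb definitions so it runs standalone:

import numpy as np

def format_list(val):
    # Recursively walk lists/arrays and render floats with two decimals
    if isinstance(val, list):
        return [format_list(num) for num in val]
    if isinstance(val, np.ndarray):
        return np.asarray([format_list(num) for num in val])
    if isinstance(val, (np.float32, float)):
        return f"{val:.2f}"
    return val

def rgb(r, g, b):
    return "#{:02x}{:02x}{:02x}".format(r, g, b)

print(format_list([0.12345, [2.71828, np.float32(3.14159)]]))  # ['0.12', ['2.72', '3.14']]
print(rgb(255, 75, 75))  # '#ff4b4b'

The two separate isinstance branches for np.float32 and float in the committed version can be collapsed into one tuple check, as above, without changing behavior (np.float32 is not a subclass of Python's float, so both checks are needed).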