import gradio as gr
import torch
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection, ResNetModel
from torchvision import transforms
from PIL import Image
import cv2
import torch.nn.functional as F
import tempfile
import os
# Load models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

resnet = ResNetModel.from_pretrained("microsoft/resnet-50").to(device)
resnet.eval()

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
model.eval()
# Preprocess a PIL image for ResNet-50 (ImageNet resize and normalization)
def preprocess_image(image):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image.convert("RGB")).unsqueeze(0)
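# Extract a global feature embedding for an image from ResNet-50's pooled output.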
def extract_embedding(image):
    image_tensor = preprocess_image(image).to(device)
    with torch.no_grad():
        output = resnet(image_tensor)
        embedding = output.pooler_output
    return embedding
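# Similarity metrics between two embeddings: cosine similarity (higher means more
# similar) and L2 distance (lower means more similar).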
def cosine_similarity(embedding1, embedding2):
    return F.cosine_similarity(embedding1, embedding2)

def l2_distance(embedding1, embedding2):
    return torch.norm(embedding1 - embedding2, p=2)
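# Save a BGR NumPy array (e.g. an OpenCV frame or crop) as an RGB PNG in a
# temporary file and return the file path.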
def save_array_to_temp_image(arr):
    rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb_arr)
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    temp_file_name = temp_file.name
    temp_file.close()
    img.save(temp_file_name)
    return temp_file_name
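# Image-guided detection with OWL-ViT: find regions of the target image that match
# the query image and return the cropped regions as BGR arrays.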
def detect_and_crop(target_image, query_image, threshold=0.6, nms_threshold=0.3):
    target_sizes = torch.Tensor([target_image.size[::-1]])
    inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)
    # The PIL target image is RGB; swap channels so the crops are BGR arrays,
    # which is what save_array_to_temp_image expects.
    img = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(
        outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes
    )
    boxes, scores = results[0]["boxes"], results[0]["scores"]
    if len(boxes) == 0:
        return []
    filtered_boxes = []
    for box in boxes:
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]
        cropped_img = img[y1:y2, x1:x2]
        if cropped_img.size != 0:
            filtered_boxes.append(cropped_img)
    return filtered_boxes
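# Walk through the video, sampling every (skipframes + 1)-th frame, run
# image-guided detection on each sampled frame, and score every returned crop
# against the query image with L2 distance and cosine similarity.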
def process_video(video_path, query_image, skipframes=0):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return []
    # The query embedding does not change between frames, so compute it once.
    query_embedding = extract_embedding(query_image)
    frame_count = 0
    all_results = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if frame_count % (skipframes + 1) == 0:
            frame_file = save_array_to_temp_image(frame)
            result_frames = detect_and_crop(Image.open(frame_file), query_image)
            os.remove(frame_file)
            for res in result_frames:
                saved_res = save_array_to_temp_image(res)
                crop_embedding = extract_embedding(Image.open(saved_res))
                os.remove(saved_res)
                dist = l2_distance(query_embedding, crop_embedding).item()
                cos = cosine_similarity(query_embedding, crop_embedding).item()
                all_results.append({'l2_dist': dist, 'cos': cos})
        frame_count += 1
    cap.release()
    return all_results
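# Aggregate the per-crop scores into averages and medians, and decide whether the
# query object is present by thresholding the average cosine similarity.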
def process_videos_and_compare(image, video, skipframes=5, threshold=0.47):
    def median(values):
        n = len(values)
        return (values[n // 2 - 1] + values[n // 2]) / 2 if n % 2 == 0 else values[n // 2]

    results = process_video(video, image, skipframes)
    if results:
        l2_dists = [item['l2_dist'] for item in results]
        cosines = [item['cos'] for item in results]
        avg_l2_dist = sum(l2_dists) / len(l2_dists)
        avg_cos = sum(cosines) / len(cosines)
        median_l2_dist = median(sorted(l2_dists))
        median_cos = median(sorted(cosines))
        result = {
            "avg_l2_dist": avg_l2_dist,
            "avg_cos": avg_cos,
            "median_l2_dist": median_l2_dist,
            "median_cos": median_cos,
            "avg_cos_dist": 1 - avg_cos,
            "median_cos_dist": 1 - median_cos,
            "is_present": avg_cos >= threshold
        }
    else:
        result = {
            "avg_l2_dist": float('inf'),
            "avg_cos": 0,
            "median_l2_dist": float('inf'),
            "median_cos": 0,
            "avg_cos_dist": float('inf'),
            "median_cos_dist": float('inf'),
            "is_present": False
        }
    return result
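# Gradio callback: adapts the UI argument order to process_videos_and_compare.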
def interface(video, image, skipframes, threshold):
    result = process_videos_and_compare(image, video, skipframes, threshold)
    return result
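# Build the Gradio UI: a video input, a query-image input, and sliders for the
# frame-skip count and the presence threshold.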
iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Video(label="Upload a Video"),
        gr.Image(type="pil", label="Upload a Query Image"),
        gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Skip Frames"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.47, label="Threshold")
    ],
    outputs=[
        gr.JSON(label="Result")
    ],
    title="Object Detection in Video",
    description="""
    **Instructions:**
    1. **Upload a Video**: Select a video file to upload.
    2. **Upload a Query Image**: Select an image file containing the object you want to detect in the video.
    3. **Set Skip Frames**: Adjust the slider to set how many frames to skip between processed frames.
    4. **Set Threshold**: Adjust the slider to set the cosine-similarity threshold used to decide whether the object is present in the video.
    5. **View Results**: The result shows the average and median distances and similarities, and whether the object is present in the video based on the threshold.
    """
)
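# Launch the app when this script is run directly.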
if __name__ == "__main__":
    iface.launch()