# Author: kasper-boy — commit e62d449 ("Create app.py"), 2.57 kB.
import io

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
# Hugging Face checkpoint for DETR with a ResNet-101 backbone. The same
# identifier is used for both the image processor and the detection model so
# their preprocessing and weights stay in sync.
MODEL_CHECKPOINT = 'facebook/detr-resnet-101'

# Created once at import time and reused by every call to object_detection().
processor = DetrImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(MODEL_CHECKPOINT)
def object_detection(image, confidence_threshold):
    """Detect objects in *image* with DETR and draw the detections on it.

    Args:
        image: PIL image to analyse, or ``None`` when the form was submitted
            without an upload.
        confidence_threshold: Minimum score a detection needs to be kept
            (the UI slider supplies a value in [0, 1]); passed straight to
            ``processor.post_process_object_detection`` as ``threshold``.

    Returns:
        A ``(result_image, detected_objects_text)`` tuple. ``result_image`` is
        a PIL image (rendered PNG) of the input with a red box and a
        ``"<label>: <score>"`` caption per detection. ``detected_objects_text``
        is those captions joined by newlines. When *image* is ``None`` the
        result is ``(None, "No image provided.")`` instead of an exception.
    """
    if image is None:
        return None, "No image provided."

    # The processor is fed RGB pixels; converting here keeps RGBA, greyscale
    # and palette images from reaching it with an unexpected channel count.
    image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); reversed to (height, width) per image.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=confidence_threshold
    )[0]

    # Use an explicit Figure instead of pyplot's global "current figure":
    # the web server may handle requests concurrently, and pyplot state is
    # shared process-wide. This figure is never registered with pyplot, so
    # no plt.close() is needed.
    fig = Figure(figsize=(16, 10))
    ax = fig.add_subplot()
    ax.imshow(image)

    detected_objects = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        xmin, ymin, xmax, ymax = (round(coord, 2) for coord in box.tolist())
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                fill=False,
                color='red',
                linewidth=3,
            )
        )
        text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
        detected_objects.append(text)
    ax.axis('off')

    # Render to an in-memory PNG and return it as a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    result_image = Image.open(buf)

    return result_image, "\n".join(detected_objects)
# Define the Gradio interface.
# The legacy ``gr.inputs`` / ``gr.outputs`` namespaces (and Slider's
# ``default=`` keyword) were removed in Gradio 4.0. The top-level components
# below work in both Gradio 3.x and 4.x; the initial slider value is ``value=``.
confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="Confidence Threshold")
image_input = gr.Image(type="pil", label="Upload an Image")
output_image = gr.Image(type="pil", label="Detected Objects")
output_textbox = gr.Textbox(label="Detected Objects List")

demo = gr.Interface(
    fn=object_detection,
    inputs=[image_input, confidence_slider],
    outputs=[output_image, output_textbox],
    title="Object Detection with DETR (ResNet-101)",
    description="Upload an image and adjust the confidence threshold to view detected objects.",
)

# Launch the Gradio interface only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()