# Page residue from the repository file viewer (not part of the program):
#   kasper-boy's picture | Update app.py | 66136c9 verified | raw | history blame | 2.56 kB
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
# Build the image processor and the detection model once, at import time, from
# the same checkpoint so that every call to the detection callback reuses them.
_CHECKPOINT = 'facebook/detr-resnet-101'
processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT)
def _to_rgb_image(image):
    """Coerce the incoming image payload into an RGB ``PIL.Image.Image``.

    Args:
        image: A PIL image, an array-like exposing ``__array_interface__``
            (e.g. a NumPy array, which is what ``gr.Image`` delivers unless
            its ``type`` is overridden), or a bytes-like object holding an
            encoded image file.

    Returns:
        The image converted to mode ``"RGB"``.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing to run detection on).
    """
    if image is None:
        raise gr.Error("Please upload an image before running detection.")
    if isinstance(image, Image.Image):
        pil_image = image
    elif hasattr(image, "__array_interface__"):
        # Raw pixel arrays are NOT encoded image files: feeding their buffer to
        # Image.open() fails, so build the image directly from the pixels.
        pil_image = Image.fromarray(image)
    else:
        # Bytes-like payload containing an encoded image (PNG, JPEG, ...).
        pil_image = Image.open(io.BytesIO(image))
    # Normalise grayscale / RGBA / palette inputs to three channels before
    # handing the image to the processor.
    return pil_image.convert("RGB")


def _render_detections(image, results):
    """Draw every detection onto ``image`` and rasterise the plot to PNG.

    Args:
        image: The RGB PIL image that was analysed.
        results: Mapping with ``"scores"``, ``"labels"`` and ``"boxes"``
            entries; each box is unpacked as ``(xmin, ymin, xmax, ymax)``.

    Returns:
        A tuple ``(rendered, captions)`` where ``rendered`` is a PIL image of
        the annotated plot and ``captions`` is a list of ``"label: score"``
        strings, one per detection.
    """
    captions = []
    # Work on an explicit figure handle rather than pyplot's implicit
    # "current figure", and always close it so a drawing error cannot leak it.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(image)
        ax.axis('off')
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = (round(coord, 2) for coord in box.tolist())
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    fill=False,
                    color='red',
                    linewidth=3,
                )
            )
            text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
            captions.append(text)
        # Save the plot to an in-memory PNG buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf), captions


def object_detection(image, confidence_threshold):
    """Detect objects in ``image`` and return an annotated picture plus a list.

    Args:
        image: The picture to analyse; see ``_to_rgb_image`` for the accepted
            forms (PIL image, pixel array, or encoded image bytes).
        confidence_threshold: Score threshold forwarded to
            ``processor.post_process_object_detection``.

    Returns:
        A tuple ``(result_image, detected_objects_text)``: the input image
        with red bounding boxes and yellow ``"label: score"`` tags drawn on
        it, and the same tags joined with newlines (empty string when nothing
        was detected).

    Raises:
        gr.Error: If no image was supplied.
    """
    image = _to_rgb_image(image)
    # Preprocess the image
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL reports (width, height); reverse it to (height, width) for the
    # post-processing step that rescales boxes to the original image.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=confidence_threshold
    )[0]
    result_image, detected_objects = _render_detections(image, results)
    # Join detected objects into a single string
    detected_objects_text = "\n".join(detected_objects)
    return result_image, detected_objects_text
# Define the Gradio interface: one image plus a confidence slider in, the
# annotated image plus a textual list of detections out.
demo = gr.Interface(
    fn=object_detection,
    inputs=[
        # Ask Gradio for a PIL image explicitly; the component's default
        # payload is a pixel array, which is not an encoded image file.
        gr.Image(type="pil", label="Upload an Image"),
        # Start at a usable threshold: without an explicit value the slider
        # sits at its minimum (0.0), so the first run would keep every box.
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Confidence Threshold"),
    ],
    outputs=[
        gr.Image(label="Detected Objects"),
        gr.Textbox(label="Detected Objects List"),
    ],
    title="Object Detection with DETR (ResNet-101)",
    description="Upload an image and get object detection results using the DETR model with a ResNet-101 backbone.",
)
# Start the web UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()