kasper-boy commited on
Commit
e62d449
1 Parent(s): 7dfd1d1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as patches
7
+ import io
8
+
9
+ # Load the processor and model
10
+ processor = DetrImageProcessor.from_pretrained('facebook/detr-resnet-101')
11
+ model = DetrForObjectDetection.from_pretrained('facebook/detr-resnet-101')
12
+
13
+ def object_detection(image, confidence_threshold):
14
+ # Preprocess the image
15
+ inputs = processor(images=image, return_tensors="pt")
16
+
17
+ # Perform object detection
18
+ outputs = model(**inputs)
19
+
20
+ # Extract bounding boxes and labels
21
+ target_sizes = torch.tensor([image.size[::-1]])
22
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=confidence_threshold)[0]
23
+
24
+ # Plot the image with bounding boxes
25
+ plt.figure(figsize=(16, 10))
26
+ plt.imshow(image)
27
+ ax = plt.gca()
28
+
29
+ detected_objects = []
30
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
31
+ box = [round(i, 2) for i in box.tolist()]
32
+ xmin, ymin, xmax, ymax = box
33
+ width, height = xmax - xmin, ymax - ymin
34
+
35
+ ax.add_patch(plt.Rectangle((xmin, ymin), width, height, fill=False, color='red', linewidth=3))
36
+ text = f'{model.config.id2label[label.item()]}: {round(score.item(), 3)}'
37
+ ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
38
+ detected_objects.append(text)
39
+
40
+ plt.axis('off')
41
+
42
+ # Save the plot to an image buffer
43
+ buf = io.BytesIO()
44
+ plt.savefig(buf, format='png')
45
+ buf.seek(0)
46
+ plt.close()
47
+
48
+ # Convert buffer to an Image object
49
+ result_image = Image.open(buf)
50
+
51
+ # Join detected objects into a single string
52
+ detected_objects_text = "\n".join(detected_objects)
53
+
54
+ return result_image, detected_objects_text
55
+
56
+ # Define the Gradio interface
57
+ confidence_slider = gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.5, label="Confidence Threshold")
58
+ image_input = gr.inputs.Image(type="pil", label="Upload an Image")
59
+ output_image = gr.outputs.Image(type="pil", label="Detected Objects")
60
+ output_textbox = gr.outputs.Textbox(label="Detected Objects List")
61
+
62
+ demo = gr.Interface(
63
+ fn=object_detection,
64
+ inputs=[image_input, confidence_slider],
65
+ outputs=[output_image, output_textbox],
66
+ title="Object Detection with DETR (ResNet-101)",
67
+ description="Upload an image and adjust the confidence threshold to view detected objects."
68
+ )
69
+
70
+ # Launch the Gradio interface
71
+ if __name__ == "__main__":
72
+ demo.launch()