# Hugging Face file-page metadata (scrape residue, not part of the script):
# clement-w's picture
# Fix multiple object detection
# aa2ee68
# raw / history blame / 6.37 kB
import pathlib
import validators
import requests
import gradio as gr
# For running inference on the TF-Hub module.
import tensorflow as tf
# For downloading the image.
# For drawing onto the image.
import numpy as np
from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
# Load the locally exported TensorFlow SavedModel once at module import time;
# every Gradio request below reuses this single `detector` instance.
print("load model...")
detector = tf.saved_model.load("model/saved_model")
def draw_bounding_box_on_image(image,
                               ymin,
                               xmin,
                               ymax,
                               xmax,
                               color,
                               font,
                               thickness=4,
                               display_str_list=()):
    """Adds a bounding box (and stacked label strings) to a PIL image in place.

    Args:
        image: PIL.Image.Image to draw on (modified in place).
        ymin, xmin, ymax, xmax: Box corners as fractions of image size (0..1).
        color: Fill color for the box outline and label background.
        font: PIL ImageFont used to render the label strings.
        thickness: Outline width in pixels.
        display_str_list: Strings drawn stacked above (or below) the box.
    """

    def _text_size(text):
        # Pillow >= 10 removed ImageFont.getsize(); prefer getbbox() and fall
        # back to getsize() on older Pillow versions.
        if hasattr(font, "getbbox"):
            bb_left, bb_top, bb_right, bb_bottom = font.getbbox(text)
            return bb_right - bb_left, bb_bottom - bb_top
        return font.getsize(text)

    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                  ymin * im_height, ymax * im_height)
    draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
               (left, top)],
              width=thickness,
              fill=color)
    # If the total height of the display strings added to the top of the
    # bounding box exceeds the top of the image, stack the strings below the
    # bounding box instead of above.
    display_str_heights = [_text_size(ds)[1] for ds in display_str_list]
    # Each display_str has a top and bottom margin of 0.05x.
    total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)
    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = top + total_display_str_height
    # Reverse list and print from bottom to top.
    for display_str in display_str_list[::-1]:
        text_width, text_height = _text_size(display_str)
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                        (left + text_width, text_bottom)],
                       fill=color)
        draw.text((left + margin, text_bottom - text_height - margin),
                  display_str,
                  fill="black",
                  font=font)
        # Move up by the full label height (text plus both margins) so stacked
        # labels do not overlap; the original subtracted the margins instead,
        # which made consecutive labels overlap by 4*margin.
        text_bottom -= text_height + 2 * margin
def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):
    """Overlay labeled boxes on an image with formatted scores and label names.

    (This docstring previously sat *outside* the function as a bare module-level
    string, so ``draw_boxes.__doc__`` was ``None``; it now lives where it
    belongs.)

    Args:
        image: HxWx3 uint8-compatible numpy array; modified in place and returned.
        boxes: Array of shape (1, N, 4) with normalized (ymin, xmin, ymax, xmax).
        class_names: Sequence of N label strings.
        scores: Array of shape (1, N) with detection confidences.
        max_boxes: Maximum number of boxes to draw.
        min_score: Minimum confidence required to draw a box.

    Returns:
        The input `image` array with the detections drawn onto it.
    """
    colors = list(ImageColor.colormap.values())
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf",
                                  25)
    except IOError:
        print("Font not found, using default font.")
        font = ImageFont.load_default()
    # Convert to PIL once up front instead of re-building the PIL image and
    # copying back to the numpy array on every single box (same result, one
    # conversion instead of N round-trips).
    image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
    for i in range(min(boxes.shape[1], max_boxes)):
        if scores[0][i] >= min_score:
            ymin, xmin, ymax, xmax = tuple(boxes[0][i])
            display_str = "{}: {}%".format(class_names[i],
                                           int(100 * scores[0][i]))
            # Deterministic per-class color choice.
            color = colors[hash(class_names[i]) % len(colors)]
            draw_bounding_box_on_image(
                image_pil,
                ymin,
                xmin,
                ymax,
                xmax,
                color,
                font,
                display_str_list=[display_str])
    np.copyto(image, np.array(image_pil))
    return image
def run_detector(url_input, image_input, minscore=0.1):
    """Run the cyclist detector on a URL image or an uploaded image.

    A valid URL takes precedence over the uploaded image.

    Args:
        url_input: Image URL from the URL tab; used when it validates.
        image_input: PIL image from the upload tab (may be None).
        minscore: Minimum confidence for a detection to be drawn.

    Returns:
        A numpy array of the image with detections drawn on it.

    Raises:
        ValueError: If neither a valid URL nor an uploaded image is provided.
    """
    if validators.url(url_input):
        img = Image.open(requests.get(url_input, stream=True).raw)
    elif image_input is not None:
        img = image_input
    else:
        # Previously `img` was simply left unbound here, so the handler died
        # with an opaque UnboundLocalError; fail with a clear message instead.
        raise ValueError("Provide a valid image URL or upload an image.")
    converted_img = tf.image.convert_image_dtype(img, tf.uint8)[
        tf.newaxis, ...]
    result = detector(converted_img)
    result = {key: value.numpy() for key, value in result.items()}
    # The custom model detects a single class, so every detection is labeled
    # "cyclist".
    labels = ["cyclist" for _ in range(len(result["detection_scores"][0]))]
    image_with_boxes = draw_boxes(
        np.array(img), result["detection_boxes"],
        labels, result["detection_scores"], min_score=minscore)
    return image_with_boxes
# Inline CSS injected into the Blocks page: centers the <h1 id="title"> below.
css = '''
h1#title {
text-align: center;
}
'''
# Top-level Gradio container; the layout is filled in under `with demo:` below.
demo = gr.Blocks(css=css)
# Raw HTML heading rendered through gr.Markdown inside the demo.
title = """<h1 id="title">Custom Cyclists detector</h1>"""
description = "todo"  # NOTE(review): placeholder — write a real description.
def set_example_image(example: list) -> dict:
    """Copy the clicked Dataset sample into the image input component."""
    selected_path = example[0]
    return gr.Image.update(value=selected_path)
def set_example_url(example: list) -> dict:
    """Copy the clicked Dataset sample into the URL textbox."""
    selected_url = example[0]
    return gr.Textbox.update(value=selected_url)
# Example image URL offered in the "Image URL" tab's Dataset picker.
urls = ["https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/cyclist-on-path-by-sea-royalty-free-image-1656931301.jpg?crop=0.727xw:0.699xh;0.134xw,0.169xh&resize=640:*"]
# ---- UI layout: two tabs (URL input / file upload) sharing one threshold ----
with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    # Confidence threshold shared by both tabs; forwarded to run_detector.
    slider_input = gr.Slider(minimum=0.0, maximum=1,
                             value=0.2, label='Prediction Threshold')
    with gr.Tabs():
        with gr.TabItem('Image URL'):
            with gr.Row():
                url_input = gr.Textbox(
                    lines=2, label='Enter valid image URL here..')
                img_output_from_url = gr.Image(shape=(640, 640))
            with gr.Row():
                example_url = gr.Dataset(components=[url_input], samples=[
                    [str(url)] for url in urls])
            url_but = gr.Button('Detect')
        with gr.TabItem('Image Upload'):
            with gr.Row():
                img_input = gr.Image(type='pil')
                # NOTE(review): 650x650 here vs 640x640 on the URL tab —
                # confirm whether the size mismatch is intentional.
                img_output_from_upload = gr.Image(shape=(650, 650))
            with gr.Row():
                # Sample images come from the local `images/` directory.
                example_images = gr.Dataset(components=[img_input],
                                            samples=[[path.as_posix()]
                                                     for path in sorted(pathlib.Path('images').rglob('*.jpg'))])
            img_but = gr.Button('Detect')
    # Both buttons call the same handler; run_detector prefers a valid URL
    # over the uploaded image, so both inputs are always passed.
    url_but.click(run_detector, inputs=[
        url_input, img_input, slider_input], outputs=img_output_from_url, queue=True)
    img_but.click(run_detector, inputs=[
        url_input, img_input, slider_input], outputs=img_output_from_upload, queue=True)
    # Clicking a Dataset sample copies it into the matching input component.
    example_images.click(fn=set_example_image, inputs=[
        example_images], outputs=[img_input])
    example_url.click(fn=set_example_url, inputs=[
        example_url], outputs=[url_input])
demo.launch(enable_queue=True)