# Page banner captured with this file: "Spaces: Runtime error".
import functools
import pickle
import random
from os import environ

import clip
import gradio as gr
import numpy as np
import requests
import torch
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Disabled alternative that loads CLIP through the `transformers` package:
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# CLIP model plus the callable that turns an image into a model-ready tensor.
orig_clip_model, orig_clip_processor = clip.load(
    "ViT-B/32",
    device=device,
    jit=False,
)

# Unsplash photo dataset; per the original author, all 25K images are in the
# train split.
dataset = load_dataset(
    "jamescalam/unsplash-25k-photos",
    split="train",
)

# Pixel height requested for dataset images and used to size the image widgets.
height = 256
@functools.lru_cache(maxsize=1)
def _load_hf_clip():
    """Load the `transformers` CLIP model and processor once and cache them.

    Loading is deferred to the first call so that importing this module does
    not pay for a second model unless `predict` is actually used.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    return (
        CLIPModel.from_pretrained(checkpoint),
        CLIPProcessor.from_pretrained(checkpoint),
    )


def predict(image, labels):
    """Zero-shot classify `image` against `labels` with the `transformers` CLIP.

    Each label is wrapped in the prompt "a photo of {label}" before scoring.

    Args:
        image: Image accepted by `CLIPProcessor` (passed through as `images=`).
        labels: Sequence of label strings.

    Returns:
        Dict mapping each label to its softmax probability as a float.
    """
    # The original body used module-level `model` / `processor` names that
    # only existed in commented-out code, so every call raised NameError.
    model, processor = _load_hf_clip()

    prompts = [f"a photo of {c}" for c in labels]
    inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image holds the image-text similarity scores; softmax over
    # the label axis turns them into per-label probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    return {k: float(v) for k, v in zip(labels, probs[0])}
def predict2(image, labels):
    """Zero-shot classify `image` against `labels` with the OpenAI CLIP model.

    Args:
        image: Image to classify, as delivered by the Gradio image component.
        labels: List of label strings to score the image against.

    Returns:
        Dict mapping each label to its softmax probability as a float. An
        empty dict is returned when there is no image or no labels, so the
        UI callback does not fail before the user has supplied both.
    """
    if image is None or not labels:
        return {}

    # The original read an undefined name `img` here instead of `image`.
    # NOTE(review): CLIP's preprocessing transform presumably expects a PIL
    # image -- confirm the `type` of the gr.Image component feeding this.
    image_input = orig_clip_processor(image).unsqueeze(0).to(device)
    text_input = clip.tokenize(labels).to(device)

    with torch.no_grad():
        # A single forward pass yields the image-text logits; the separate
        # encode_image / encode_text calls in the original were never used.
        logits_per_image, _ = orig_clip_model(image_input, text_input)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    return {k: float(v) for k, v in zip(labels, probs[0])}
def rand_image():
    """Return the URL of a randomly chosen dataset photo.

    A `?h=<height>` query is appended using the module-level `height`; per
    the original author, Unsplash allows dynamic requests, including the
    size of the image.
    """
    row_index = random.randrange(dataset.num_rows)
    photo_url = dataset[row_index]["photo_image_url"]
    size_query = f"?h={height}"
    return photo_url + size_query
def set_labels(text):
    """Parse a comma-separated string into a list of label strings.

    Whitespace around each label is stripped and empty entries are dropped,
    so "spring, summer,, fall" yields ["spring", "summer", "fall"] rather
    than labels with leading spaces or blank names.

    Args:
        text: Raw contents of the label textbox; may be empty or None.

    Returns:
        List of non-empty, stripped label strings (possibly empty).
    """
    if not text:
        return []
    stripped = (part.strip() for part in text.split(","))
    return [label for label in stripped if label]
# Handle to the remote "ryaalbr/caption" Space, used below as a callable.
# The token comes from the `api_key` environment variable; if it is unset,
# this lookup raises KeyError while the module is loading.
get_caption = gr.load(
    "ryaalbr/caption",
    src="spaces",
    hf_token=environ["api_key"],
)


def generate_text(image, model_name):
    """Return the caption produced by the remote caption Space.

    Args:
        image: Image value forwarded unchanged to `get_caption`.
        model_name: Captioning model choice forwarded unchanged.
    """
    caption_text = get_caption(image, model_name)
    return caption_text
# get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"]) | |
# def search_images(text): | |
# return get_images(text, api_name="images") | |
# Precomputed search index shipped next to this script.
emb_filename = 'unsplash-25k-photos-embeddings-indexes.pkl'

# NOTE(security): pickle.load can execute arbitrary code; only load this
# file when it comes from a trusted source.
with open(emb_filename, 'rb') as emb:
    _index_payload = pickle.load(emb)

# The pickle holds three objects: an id -> URL mapping, the image names, and
# the image embeddings (one per name, presumably -- confirm against the file).
id2url, img_names, img_emb = _index_payload
def search(search_query):
    """Return up to five image URLs that best match a text description.

    The query is embedded with the CLIP text encoder, L2-normalised and
    scored against the precomputed image embeddings in `img_emb` with a dot
    product (the original author describes this as cosine similarity, which
    presumably relies on `img_emb` being normalised too -- confirm).

    Args:
        search_query: Free-text description of the desired image.

    Returns:
        List of at most five URL strings, best match first, each with a
        "?h=512" size query appended.
    """
    candidate_count = 15  # best-ranked photos considered
    max_results = 5       # URLs actually returned

    with torch.no_grad():
        # Tokens must live on the same device as the model; the original
        # left them on the CPU, which is a device mismatch under CUDA.
        tokens = clip.tokenize(search_query).to(device)
        text_encoded = orig_clip_model.encode_text(tokens)
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    text_features = text_encoded.cpu().numpy()

    # One similarity score per indexed photo, then indices of the best ones.
    similarities = (text_features @ img_emb.T).squeeze(0)
    best_photos = similarities.argsort()[::-1][:candidate_count]

    imgs = []
    for file_name in img_names[best_photos]:
        # Drop the file extension to get the photo id. rsplit tolerates names
        # with no dot or several dots, where `a, _ = name.split('.')` raised.
        photo_id = str(file_name).rsplit(".", 1)[0]
        url = id2url.get(photo_id, "")
        if not url:
            # No known URL for this photo; try the next candidate.
            continue
        imgs.append(url + "?h=512")
        if len(imgs) == max_results:
            break
    return imgs
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the Row/Column layout against the running app.
with gr.Blocks() as demo:
    # ---- Tab 1: zero-shot classification of a random dataset image ----
    with gr.Tab("Zero-Shot Classification"):
        # Hidden state holding the parsed label list; starts out empty.
        labels = gr.State([])
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset and analyze it against the labels
3. Click **Re-Classify Image** to re-run classification on current image after changing labels"""
        gr.Markdown(instructions)

        with gr.Row(variant="compact"):
            label_text = gr.Textbox(
                show_label=False,
                placeholder="Enter classification labels",
            ).style(container=False)

        gr.Examples(
            [
                "spring, summer, fall, winter",
                "mountain, city, beach, ocean, desert, forest, valley",
                "red, blue, green, white, black, purple, brown",
                "person, animal, landscape, something else",
                "day, night, dawn, dusk",
            ],
            inputs=label_text,
        )

        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    reclass_btn = gr.Button("Re-Classify Image").style(full_width=False)
            cf = gr.Label()

        # Re-parse the labels on every edit, on focus loss and on Enter, so
        # the stored list is up to date before any classification runs.
        for _register in (label_text.change, label_text.blur, label_text.submit):
            _register(fn=set_labels, inputs=label_text, outputs=labels)

        get_btn.click(fn=rand_image, outputs=im)
        # A new image is classified automatically; the button re-runs the
        # classification on the current image.
        im.change(predict2, inputs=[im, labels], outputs=cf)
        reclass_btn.click(predict2, inputs=[im, labels], outputs=cf)

    # ---- Tab 2: captioning through the remote caption Space ----
    with gr.Tab("Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                im_cap = gr.Image(interactive=False, type='filepath').style(height=height)
                model_name = gr.Radio(
                    choices=["COCO", "Conceptual captions"],
                    type="value",
                    value="COCO",
                    label="Model",
                ).style(container=True, item_container=False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption')

        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        caption_btn.click(generate_text, inputs=[im_cap, model_name], outputs=caption)

    # ---- Tab 3: text-to-image search over the precomputed embeddings ----
    with gr.Tab("Image Search"):
        with gr.Column(variant="panel"):
            desc = gr.Textbox(
                show_label=False,
                placeholder="Enter description",
            ).style(container=False)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2, 2, 3, 5))

        search_btn.click(search, inputs=desc, outputs=gallery)

demo.launch()