Spaces:
Sleeping
Sleeping
import html
import os
from base64 import b64encode
from io import BytesIO

import gradio as gr
from datasets import load_dataset
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder
from sentence_transformers import SentenceTransformer
# Sentence-transformers CLIP checkpoint; process_input uses it to encode the
# query text into the dense vector.
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32')

# Product catalogue. The image column is kept separately so that the remaining
# columns can be handled without carrying the image data along.
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")
images = fashion['image']
metadata = fashion.remove_columns('image')
# De-duplicated product names offered as dropdown choices.
item_list = list(set(metadata['productDisplayName']))

INDEX_NAME = 'srinivas-hybrid-search'
# The environment variable name must be passed as a string; the original passed
# an undefined bare name, which raised NameError at import time. The
# upper-case spelling is accepted as a fallback.
PINECONE_API_KEY = os.getenv("pinecone_api_key") or os.getenv("PINECONE_API_KEY")
pinecone = Pinecone(api_key=PINECONE_API_KEY)
index = pinecone.Index(INDEX_NAME)

# Sparse (BM25) encoder fitted on the product names; process_input uses it to
# encode the query into the sparse vector.
bm25 = BM25Encoder()
bm25.fit(metadata['productDisplayName'])
# Function to display images in a grid layout
def display_result(image_batch, match_batch):
    """Render images and their titles as an HTML grid.

    Args:
        image_batch: Iterable of image objects exposing ``mode``,
            ``convert(mode)`` and ``save(fp, format=...)`` (PIL-style).
        match_batch: Iterable of titles, paired positionally with
            ``image_batch``. Extra items in the longer iterable are ignored.

    Returns:
        An HTML string containing a four-column CSS grid with one
        ``<figure>`` per (image, title) pair. Images are embedded inline as
        base64-encoded PNG data URIs.
    """
    figures = []
    for img, title in zip(image_batch, match_batch):
        # Ensure the image is in the correct format for encoding
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Convert image to bytes and encode as base64
        buffer = BytesIO()
        img.save(buffer, format='PNG')
        img_str = b64encode(buffer.getvalue()).decode('utf-8')
        # Titles come from index metadata; escape them so characters such as
        # '&' or '<' are shown literally instead of being parsed as markup.
        safe_title = html.escape(str(title))
        # Create HTML figure element with the image title
        figures.append(f'''
        <figure style="margin: 0; padding: 0; text-align: left;">
        <figcaption style="font-weight: bold; margin:0;">{safe_title}</figcaption>
        <img src="data:image/png;base64,{img_str}" style="width: 180px; height: 240px; margin: 0;" >
        </figure>
        ''')
    # Combine all figures into a single HTML string with reduced spacing
    html_content = f'''
    <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 20px; align-items: start;">
    {''.join(figures)}
    </div>
    '''
    return html_content
# Function to scale vectors based on alpha for hybrid search
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense/sparse vector pair for a hybrid query.

    Dense components are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``; the sparse indices are passed through as-is.

    Args:
        dense: Iterable of numeric dense-vector components.
        sparse: Mapping with ``'indices'`` and ``'values'`` entries.
        alpha: Weight of the dense part, between 0 and 1 inclusive.

    Returns:
        A ``(scaled_dense, scaled_sparse)`` tuple: a list and a dict with the
        same ``'indices'``/``'values'`` keys as the input.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")

    sparse_weight = 1 - alpha
    scaled_sparse = dict(
        indices=sparse['indices'],
        values=[value * sparse_weight for value in sparse['values']],
    )
    scaled_dense = [component * alpha for component in dense]
    return scaled_dense, scaled_sparse
# Function to process the input text and slider value, with error handling
def process_input(query, slider_value):
    """Run a hybrid search for ``query`` and return the results as HTML.

    The query is encoded into a sparse vector (``bm25``) and a dense vector
    (``model``), both are weighted with ``hybrid_scale`` using the slider
    value, and the top 12 matches are fetched from ``index``. Each match id is
    used as an integer position into ``images``.

    Args:
        query: Search text. ``None``, empty or whitespace-only input
            short-circuits without querying the index.
        slider_value: Dense/sparse weighting; anything ``float()`` accepts.

    Returns:
        The HTML grid from ``display_result`` on success, an empty string for
        a blank query, or a red ``<p>`` error message when any step fails.
    """
    # Nothing to search for: skip the encoders and the index round-trip.
    if not str(query or "").strip():
        return ""
    try:
        alpha = float(slider_value)
        sparse = bm25.encode_queries(query)
        dense = model.encode(query).tolist()
        hdense, hsparse = hybrid_scale(dense, sparse, alpha)
        result = index.query(
            top_k=12,
            vector=hdense,  # Use hybrid dense vector
            sparse_vector=hsparse,  # Use hybrid sparse vector
            include_metadata=True,
        )
        found = result["matches"]
        imgs = [images[int(match["id"])] for match in found]
        matches = [match["metadata"]['productDisplayName'] for match in found]
        print(f"No. of matching images: {len(imgs)}")
        print(matches)
        return display_result(imgs, matches)
    except Exception as e:
        # UI boundary: report the failure in the output pane rather than
        # raising. The exception text is escaped because it can echo user
        # input and would otherwise be parsed as markup.
        return f"<p style='color:red;'>Not found. Try another search: {html.escape(str(e))}</p>"
# Function to update the textbox value when a dropdown choice is selected
def update_textbox(choice):
    """Return the dropdown selection unchanged so it can be used as the textbox value."""
    selected = choice
    return selected
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Search for Your Fashion Item")
    # NOTE(review): the original indentation was lost; all three inputs are
    # assumed to share this row - confirm against the deployed layout.
    with gr.Row():
        # The hint is shown as help text instead of being the initial value,
        # because it is not one of the available choices.
        dropdown = gr.Dropdown(
            choices=item_list,
            label="Select an item from here..",
            info="Select an item from this list or start typing",
            value=None,
            interactive=True,
        )
        # The hint is a placeholder, not a value, so it is never submitted as
        # a search query (e.g. when only the slider is moved).
        text_input = gr.Textbox(
            label="Alternatively, enter item text..",
            placeholder="Type-in what you are looking for",
            interactive=True,
        )
        slider = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.1,
            value=0.5,
            label="Adjust the Slider to get better results that suit what you are looking for..",
            interactive=True,
        )
    # Automatically update the text input when a dropdown selection is made
    dropdown.change(fn=update_textbox, inputs=dropdown, outputs=text_input)
    # HTML output box to display images
    html_output = gr.HTML(label="Relevant Images")
    # Process and display images based on text input or slider changes
    text_input.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
    slider.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)

demo.launch(debug=True, share=True)