# Author: vsrinivas
# Last commit: b8dc911 "Update app.py" (verified)
import gradio as gr
from io import BytesIO
from base64 import b64encode
from pinecone_text.sparse import BM25Encoder
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import os
import re
####################
import pandas as pd
##########################
# Dense encoder for text queries (CLIP ViT-B/32 via sentence-transformers).
model = SentenceTransformer('sentence-transformers/clip-ViT-B-32')
fashion = load_dataset("ashraq/fashion-product-images-small", split="train")

# Pandas copy of the full dataset (image column included); used to fetch a
# product's image by its exact display name.
fashion_df = pd.DataFrame(fashion)

images = fashion['image']
metadata = fashion.remove_columns('image')
# Unique product names offered in the refinement dropdown. Sorted because the
# iteration order of a `set` of strings changes between interpreter runs (hash
# randomisation), which made the dropdown order -- and the first match it
# preselects -- differ from one restart to the next.
item_list = sorted(set(metadata['productDisplayName']))

INDEX_NAME = 'srinivas-hybrid-search'
# Read from the lower-case `pinecone_api_key` environment variable.
PINECONE_API_KEY = os.getenv('pinecone_api_key')
pinecone = Pinecone(api_key=PINECONE_API_KEY)
index = pinecone.Index(INDEX_NAME)

# Sparse encoder: fit BM25 statistics on the product names (presumably the
# same corpus used when the index's sparse vectors were built -- confirm).
bm25 = BM25Encoder()
bm25.fit(metadata['productDisplayName'])
def display_result(image_batch, match_batch):
    """Render product images with their titles as a self-contained HTML grid.

    Each image is re-encoded as PNG and embedded inline as a base64 data URI,
    so the returned markup needs no separate image hosting.

    Args:
        image_batch: Iterable of image objects exposing ``mode``, ``convert``
            and ``save`` (PIL-style; presumably ``PIL.Image.Image`` values from
            the dataset's ``image`` column -- confirm).
        match_batch: Iterable of titles, paired positionally with
            ``image_batch``; surplus items in the longer iterable are ignored
            (``zip`` semantics).

    Returns:
        An HTML string: a 4-column CSS grid of ``<figure>`` elements, each
        with an HTML-escaped caption.
    """
    from html import escape  # only this function needs it

    figures = []
    for img, title in zip(image_batch, match_batch):
        # Normalise to RGB before PNG encoding (presumably so modes PNG cannot
        # store, e.g. CMYK, do not make `save` fail).
        if img.mode != 'RGB':
            img = img.convert('RGB')
        buffer = BytesIO()
        img.save(buffer, format='PNG')
        img_str = b64encode(buffer.getvalue()).decode('utf-8')
        # Titles come from dataset/index metadata; escape them so characters
        # such as '&', '<' or '"' cannot break (or inject into) the markup.
        safe_title = escape(str(title))
        figures.append(f'''
        <figure style="margin: 0; padding: 0; text-align: left;">
            <figcaption style="font-weight: bold; margin:0;">{safe_title}</figcaption>
            <img src="data:image/png;base64,{img_str}" alt="{safe_title}" style="width: 180px; height: 240px; margin: 0;" >
        </figure>
        ''')
    body = ''.join(figures)
    html_content = f'''
    <div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 20px; align-items: start;">
        {body}
    </div>
    '''
    return html_content
def hybrid_scale(dense, sparse, alpha):
    """Weight dense and sparse query vectors for convex hybrid search.

    The dense vector is scaled by ``alpha`` and the sparse values by
    ``1 - alpha``: ``alpha=1`` keeps only the dense signal, ``alpha=0`` only
    the sparse one.

    Args:
        dense: Sequence of floats (the dense query embedding).
        sparse: Mapping with ``'indices'`` and ``'values'`` lists; any other
            keys are not carried over.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        ``(hdense, hsparse)``. ``hdense`` and ``hsparse['values']`` are new
        lists; ``hsparse['indices']`` is the same list object as
        ``sparse['indices']``. The inputs are not modified.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or is NaN.
    """
    # The chained comparison also rejects NaN, which slipped past the old
    # `alpha < 0 or alpha > 1` test because every comparison with NaN is False.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
def process_input(query, slider_value):
    """Run a hybrid (dense + sparse) search for ``query`` and render it as HTML.

    Args:
        query: Free-text search string from the textbox.
        slider_value: Hybrid weight ``alpha`` (coerced with ``float``) passed
            to ``hybrid_scale``: 1 weights only the dense query vector, 0 only
            the sparse one.

    Returns:
        HTML for the result grid; ``""`` for a blank query; or a red
        "Not found" message if encoding, querying or rendering fails.
    """
    print(f"Query: {query}")
    # Gradio fires `change` on every edit, including when the box is cleared
    # (and the slider can be moved while it is empty). A blank query has
    # nothing to search for, so render an empty result area rather than
    # matching every product and querying the index.
    if not query or not query.strip():
        return ""
    try:
        alpha = float(slider_value)
        sparse = bm25.encode_queries(query)
        dense = model.encode(query).tolist()
        hdense, hsparse = hybrid_scale(dense, sparse, alpha)
        result = index.query(
            top_k=12,
            vector=hdense,
            sparse_vector=hsparse,
            include_metadata=True,
        )
        hits = result["matches"]
        # Vector ids are presumably row positions in the dataset -- confirm
        # against the script that populated the index.
        imgs = [images[int(hit["id"])] for hit in hits]
        matches = [hit["metadata"]['productDisplayName'] for hit in hits]

        # When the query is exactly a product name (e.g. chosen from the
        # dropdown), pin that product's image to the front of the grid. A
        # direct lookup replaces the old full regex scan + `isin` filter,
        # which only ever served this exact-name check.
        exact = fashion_df.loc[fashion_df['productDisplayName'] == query, 'image']
        if not exact.empty:
            imgs.insert(0, exact.iat[0])
            matches.insert(0, query)

        print(f"No. of matching images: {len(imgs)}")
        print(matches)
        return display_result(imgs, matches)
    except Exception:
        # UI boundary: show a friendly message, but keep the traceback in the
        # server log instead of silently discarding it.
        import traceback
        traceback.print_exc()
        return "<p style='color:red;'>Not found. Try another search</p>"
def update_textbox(choice):
    """Copy the dropdown's current selection into the search textbox.

    Wired as the dropdown's ``change`` handler; the value is passed through
    unchanged.
    """
    selected_value = choice
    return selected_value
def text_process(search_string):
    """Show the refinement dropdown, filled with products matching every word.

    A product matches when each whitespace-separated word of
    ``search_string`` occurs somewhere in its name, compared case-insensitively.
    An empty search matches every product.

    Args:
        search_string: Text currently in the search box.

    Returns:
        Two ``gr.update`` objects, both aimed at the dropdown (the UI wires
        ``outputs=[dropdown, dropdown]``). The first makes it visible. The
        second sets its choices and preselects the first match, or ``""`` if
        nothing matched.
    """
    # Lower-case both sides instead of `.title()`-casing the query and
    # matching case-sensitively: title() mangles words ("men's" -> "Men'S")
    # and capitalises words the catalogue writes in lower case ("of", "and").
    words = (search_string or "").lower().split()
    filtered_items = []
    for item in item_list:
        lowered = item.lower()
        if all(word in lowered for word in words):
            filtered_items.append(item)
    first_match = filtered_items[0] if filtered_items else ""
    return gr.update(visible=True), gr.update(choices=filtered_items, value=first_match)
with gr.Blocks() as demo:
    gr.Markdown("# Get Fashion Items Recommended Based On Your Search..\n"
                "## Recommender System implemented based on Pinecone Vector Database with Dense & Sparse Embeddings and Hybrid Search..")
    with gr.Row():
        text_input = gr.Textbox(label="Type-in what you are looking for..")
        submit_btn = gr.Button("Click this button for further filtering..")
    # Hidden until the button is clicked; text_process makes it visible.
    dropdown = gr.Dropdown(label="Click here and select to narrow your search..",
                           value="Select an item from this list or start typing",
                           allow_custom_value=True, interactive=True, visible=False)
    slider = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5,
                       label="Adjust the Slider to get better recommendations that suit what you are looking for..",
                       interactive=True)
    # Choosing a dropdown entry copies it into the textbox, which (via the
    # textbox change handler below) triggers a new search.
    dropdown.change(fn=update_textbox, inputs=dropdown, outputs=text_input)
    html_output = gr.HTML(label="Relevant Images")
    # text_process returns two updates, one per listed output (both the dropdown).
    submit_btn.click(fn=text_process, inputs=[text_input], outputs=[dropdown, dropdown])
    # Re-run the search whenever the query text or the hybrid weight changes.
    text_input.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)
    slider.change(fn=process_input, inputs=[text_input, slider], outputs=html_output)

# Launch only when run as a script, so importing this module (e.g. by
# `gradio` reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)