import gradio as gr
from sentence_transformers import SentenceTransformer
import torch
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer
from datasets import load_dataset
import os
from pinecone import Pinecone

os.environ['PINECONE_API_KEY'] = '884344f6-d820-4bc8-9edf-4157373df452'
pc = Pinecone(api_key=os.environ.get('PINECONE_API_KEY'))
index = pc.Index('pubmed-splade')
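# NOTE: the hybrid sparse-dense query in search() below assumes this index was
# created with the dotproduct metric; other metrics do not accept a sparse_vector.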

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# check device being run on
if device != 'cuda':
    print("==========\n" +
          "WARNING: You are not running on GPU so this may be slow.\n" +
          "\n==========")

# dense retrieval model
dense_model = SentenceTransformer(
    'msmarco-bert-base-dot-v5',
    device=device
)

# sparse retrieval model (SPLADE)
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
sparse_model = Splade(sparse_model_id, agg='max')
sparse_model.to(device)  # moves to GPU if possible
sparse_model.eval()

tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
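# SPLADE assigns one weight per token id in the model vocabulary, so the tokenizer
# must come from the same checkpoint for the sparse indices and values to line up.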

data = load_dataset('Binaryy/cream_listings', split='train')
df = data.to_pandas()


def encode(text: str):
    # create dense vec
    dense_vec = dense_model.encode(text).tolist()
    # create sparse vec
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        sparse_vec = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    # convert to the {"indices": [...], "values": [...]} format Pinecone expects
    indices = sparse_vec.nonzero().squeeze().cpu().tolist()
    values = sparse_vec[indices].cpu().tolist()
    sparse_dict = {"indices": indices, "values": values}
    # return vecs
    return dense_vec, sparse_dict
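
# Illustrative only (values are made up): encode("2 bedroom apartment") returns a
# 768-dim dense list plus a dict such as {"indices": [2155, ...], "values": [0.41, ...]}.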


def search(query):
    dense, sparse = encode(query)
    # hybrid sparse-dense query against the Pinecone index
    xc = index.query(
        vector=dense,
        sparse_vector=sparse,
        top_k=5,  # how many results to return
        include_metadata=True
    )
    match_ids = [match['id'].split('-')[0] for match in xc['matches']]
    # filter the existing DataFrame based on '_id'
    filtered_df = df[df['_id'].isin(match_ids)]
    attributes_to_extract = ['_id', 'title', 'location', 'features', 'description', 'images',
                             'videos', 'available', 'price', 'attachedDocument', 'year',
                             'carCondition', 'engineType', 'colour', 'model', 'noOfBed',
                             'noOfBathroom', 'locationISO', 'forRent', 'views', 'thoseWhoSaved',
                             'createdAt', 'updatedAt', '__v', 'category._id', 'category.title',
                             'category.slug', 'category.isAdminAllowed', 'category.createdAt',
                             'category.updatedAt', 'category.__v', 'postedBy.pageViews.value',
                             'postedBy.pageViews.users', 'postedBy.totalSaved.value',
                             'postedBy.totalSaved.users', 'postedBy._id', 'postedBy.firstName',
                             'postedBy.lastName', 'postedBy.about', 'postedBy.cover',
                             'postedBy.email', 'postedBy.password', 'postedBy.isAdmin',
                             'postedBy.savedListing', 'postedBy.isVerified',
                             'postedBy.verifiedProfilePicture', 'postedBy.profilePicture',
                             'postedBy.pronoun', 'postedBy.userType', 'postedBy.accountType',
                             'postedBy.subscribed', 'postedBy.noOfSubscription',
                             'postedBy.totalListing', 'postedBy.sellerType', 'postedBy.createdAt',
                             'postedBy.updatedAt', 'postedBy.__v', 'postedBy.address',
                             'postedBy.city', 'postedBy.country', 'postedBy.gender',
                             'postedBy.nationality', 'postedBy.verificationType', 'postedBy.dob',
                             'postedBy.locationISO', 'postedBy.state', 'postedBy.zipCode',
                             'postedBy.otherNames', 'postedBy.facebookUrl', 'postedBy.instagramUrl',
                             'postedBy.phoneNumber1', 'postedBy.phoneNumber2', 'postedBy.websiteUrl',
                             'postedBy.accountName', 'postedBy.accountNo', 'postedBy.bankName',
                             'string_features', 'complete_description']
    extracted_data = filtered_df[attributes_to_extract]
    result_json = extracted_data.to_json(orient='records')
    return result_json
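
# Illustrative usage (hypothetical query string):
#   search("Toyota Camry 2018")  # -> JSON string of the matching listing records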


# Create a Gradio interface
iface = gr.Interface(
    fn=search,
    inputs="text",
    outputs="json",
    title="Semantic Search Prototype",
    description="Enter your query to search.",
)

# Launch Gradio interface
iface.launch(share=True)