# Spaces: Runtime error
import os

import gradio as gr
import pinecone
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer
# Connect to Pinecone and open the index that `search` queries.
#
# SECURITY: the literal key below has been committed in plain text and must be
# treated as leaked -- rotate it and supply the replacement through the
# PINECONE_API_KEY environment variable. The literals are kept only as
# fallbacks so that existing deployments keep starting unchanged.
pinecone.init(
    api_key=os.environ.get(
        'PINECONE_API_KEY', '884344f6-d820-4bc8-9edf-4157373df452'
    ),
    environment=os.environ.get('PINECONE_ENVIRONMENT', 'gcp-starter'),
)
index = pinecone.Index('pubmed-splade')
# Pick the torch device: use CUDA when it is available, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
else:
    # Let the user know that CPU execution may be slow.
    print(
        "==========\n"
        "WARNING: You are not running on GPU so this may be slow.\n"
        "\n=========="
    )
# Dense encoder, created directly on the selected device.
dense_model = SentenceTransformer('msmarco-bert-base-dot-v5', device=device)

# Sparse side: the tokenizer and the SPLADE model share one checkpoint id.
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
sparse_model = Splade(sparse_model_id, agg='max')
sparse_model.to(device)  # same device as the dense model (GPU if present)
sparse_model.eval()  # the model is only used for inference in this script
# Load the train split of the listings dataset and keep a pandas copy of it;
# `search` looks matched ids up in `df`.
data = load_dataset(path='Binaryy/cream_listings', split='train')
df = data.to_pandas()
def encode(text: str):
    """Encode ``text`` into the dense and sparse vectors used by ``search``.

    Args:
        text: The string to embed.

    Returns:
        A ``(dense_vec, sparse_dict)`` tuple. ``dense_vec`` is the output of
        ``dense_model.encode`` converted to a list. ``sparse_dict`` has the
        form ``{"indices": [...], "values": [...]}`` and holds the positions
        and weights of the non-zero entries of the sparse representation.
        Both entries are always lists, including when only a single entry is
        non-zero.
    """
    # create dense vec
    dense_vec = dense_model.encode(text).tolist()
    # create sparse vec
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        sparse_vec = sparse_model(
            d_kwargs=input_ids.to(device)
        )['d_rep'].squeeze()
    # nonzero() returns one row per non-zero entry. flatten() keeps the result
    # 1-D even when there is exactly one such entry, whereas squeeze() would
    # collapse it to a 0-d tensor and tolist() would then yield a bare int.
    nonzero_idx = sparse_vec.nonzero().flatten()
    # convert to dictionary format
    sparse_dict = {
        "indices": nonzero_idx.cpu().tolist(),
        "values": sparse_vec[nonzero_idx].cpu().tolist(),
    }
    # return vecs
    return dense_vec, sparse_dict
def search(query):
    """Query the index with dense and sparse vectors and return listings.

    Args:
        query: Free-text search string; it is passed to ``encode``.

    Returns:
        A JSON string (``orient='records'``) with one record per matched
        listing found in ``df``, restricted to a fixed set of columns and
        ordered like the matches returned by the index query.
    """
    dense, sparse = encode(query)
    # query
    xc = index.query(
        vector=dense,
        sparse_vector=sparse,
        top_k=5,  # how many results to return
        include_metadata=True
    )
    # The part of each match id before the first '-' is used as the listing
    # `_id`. dict.fromkeys drops repeats while preserving match order.
    match_ids = list(dict.fromkeys(
        match['id'].split('-')[0] for match in xc['matches']
    ))
    rank = {listing_id: pos for pos, listing_id in enumerate(match_ids)}
    # Query the existing DataFrame based on 'id'
    filtered_df = df[df['_id'].isin(match_ids)]
    # Boolean masking keeps the DataFrame's own row order, which would discard
    # the ranking of the matches; reorder the rows to follow match order.
    order = filtered_df['_id'].map(rank).to_numpy().argsort(kind='stable')
    filtered_df = filtered_df.iloc[order]
    attributes_to_extract = [
        '_id', 'postedBy.accountName', 'images', 'title', 'location', 'price',
    ]
    extracted_data = filtered_df[attributes_to_extract]
    return extracted_data.to_json(orient='records')
# Build the web UI: one text input wired to `search`, shown as JSON output.
iface = gr.Interface(
    title="Semantic Search Prototype",
    description="Enter your query to perform a semantic search.",
    fn=search,
    inputs="text",
    outputs="json",
)
# Start the app, passing share=True to launch().
iface.launch(share=True)