# PubMedSearch / app.py
# Author: shrut123 — commit fcee9ac ("Update app.py"), 2.83 kB.
import os
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from splade.models.transformer_rep import Splade
import pinecone
# Initialize Pinecone connection
api_key = os.getenv('PINECONE_API_KEY')
# Environment is overridable via PINECONE_ENVIRONMENT; defaults to the
# previously hard-coded value.
pinecone_environment = os.getenv('PINECONE_ENVIRONMENT', 'us-east1-gcp')
pinecone.init(api_key=api_key, environment=pinecone_environment)
index_name = 'pubmed-splade'
# Connect to the Pinecone index. list_indexes() is a network call, so fetch it
# once; `or []` guards against a None/empty result.
available_indexes = pinecone.list_indexes() or []
if index_name in available_indexes:
    index = pinecone.Index(index_name)
else:
    st.error("Pinecone index not found! Ensure the correct Pinecone index is being used.")
    # Halt this script run: the UI below calls `index.query`, and `index`
    # was never bound, so continuing would only raise a NameError.
    st.stop()
# Initialize Dense and Sparse models
device = 'cuda' if torch.cuda.is_available() else 'cpu'
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'


@st.cache_resource
def _load_models(model_device: str, splade_id: str):
    """Load the dense encoder, the SPLADE model and its tokenizer.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps these objects alive across reruns so they are
    loaded once per (device, model id) instead of on every keystroke.

    Args:
        model_device: Torch device string the models are moved to.
        splade_id: Hugging Face model id for both the SPLADE model and tokenizer.

    Returns:
        ``(dense_model, sparse_model, tokenizer)``.
    """
    # Dense model (Sentence-BERT)
    dense = SentenceTransformer('msmarco-bert-base-dot-v5', device=model_device)
    # Sparse model (SPLADE), built with agg='max' and put in inference mode
    sparse = Splade(splade_id, agg='max').to(model_device)
    sparse.eval()
    # Tokenizer for sparse model
    tok = AutoTokenizer.from_pretrained(splade_id)
    return dense, sparse, tok


dense_model, sparse_model, tokenizer = _load_models(device, sparse_model_id)
# Function to encode query into dense and sparse vectors
def encode(text: str):
    """Encode *text* into a dense vector and a sparse SPLADE vector.

    Relies on the module-level ``dense_model``, ``tokenizer``,
    ``sparse_model`` and ``device``.

    Args:
        text: The query string.

    Returns:
        ``(dense_vec, sparse_dict)``. ``dense_vec`` is
        ``dense_model.encode(text)`` converted to a list. ``sparse_dict`` is
        ``{"indices": [...], "values": [...]}``, holding the positions of the
        non-zero entries of the SPLADE ``'d_rep'`` output and their weights.
        Both entries are always lists.
    """
    # Dense vector
    dense_vec = dense_model.encode(text).tolist()
    # Sparse vector; truncate so over-long queries don't exceed the
    # tokenizer/model maximum input length.
    tokens = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        sparse_vec = sparse_model(d_kwargs=tokens.to(device))['d_rep'].squeeze()
    # Positions of non-zero weights as a 1-D tensor. `nonzero().squeeze()`
    # collapses to a 0-d tensor when exactly one entry is non-zero, which makes
    # `.tolist()` return a bare int/float instead of a list.
    nz = torch.nonzero(sparse_vec, as_tuple=True)[0]
    indices = nz.cpu().tolist()
    values = sparse_vec[nz].cpu().tolist()
    sparse_dict = {"indices": indices, "values": values}
    return dense_vec, sparse_dict
# Function for hybrid search scaling
def hybrid_scale(dense, sparse, alpha: float):
    """Weight a dense/sparse query pair for hybrid search.

    The dense components are multiplied by ``alpha`` and the sparse values by
    ``1 - alpha``, so ``alpha=1`` keeps only the dense signal and ``alpha=0``
    keeps only the sparse one.

    Args:
        dense: Iterable of dense-vector components.
        sparse: Mapping with ``'indices'`` and ``'values'`` entries.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        ``(hdense, hsparse)``: the scaled dense list and a new dict with the
        same ``'indices'`` and the scaled ``'values'``.

    Raises:
        ValueError: If ``alpha`` is outside [0, 1] or is NaN.
    """
    # Chained comparison also rejects NaN (it fails every comparison), which
    # `alpha < 0 or alpha > 1` would silently accept.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
# Streamlit UI
st.title("PubMed Search Application")
query = st.text_input("Enter your query:", "")
# Slider to control sparse-dense scaling
alpha = st.slider("Hybrid Search Weight (Dense vs Sparse)", 0.0, 1.0, 0.5)
# Ignore whitespace-only input so we don't run a meaningless search.
query = query.strip()
if query:
    # Encode the query
    dense_vec, sparse_vec = encode(query)
    # Scale vectors based on slider value
    hdense, hsparse = hybrid_scale(dense_vec, sparse_vec, alpha)
    # Query Pinecone index
    response = index.query(
        vector=hdense,
        sparse_vector=hsparse,
        top_k=3,
        include_metadata=True,
    )
    matches = response['matches']
    # Display results
    st.write(f"Top results for query: **{query}**")
    if not matches:
        st.info("No matching documents found.")
    for match in matches:
        # Tolerate records stored without metadata or without a 'context'
        # field instead of crashing the whole page with a KeyError.
        metadata = match.get('metadata') or {}
        st.write(f"**Score**: {match['score']}")
        st.write(f"**Context**: {metadata.get('context', '(no context available)')}")
        st.write("---")