# PubMedSearch / app.py
# Author: shrut123
# Commit: febccb7 ("Update app.py", verified)
# Size: 3.77 kB
import os
import streamlit as st
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
import torch
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer
# Module-level handle to the Pinecone index. It stays None unless the
# connection step further down the script succeeds.
index = None

# Page heading shown at the top of the Streamlit app.
st.title("Medical Hybrid Search")
def initialize_pinecone():
    """Create a Pinecone client from the ``PINECONE_API_KEY`` environment variable.

    Returns:
        A ``Pinecone`` client when the variable is set to a non-empty value.
        Otherwise ``None``, after displaying an error message in the app.
    """
    key = os.getenv('PINECONE_API_KEY')
    if not key:
        # Missing or empty key: tell the user how to fix it instead of crashing.
        st.error("Pinecone API key not found! Please set the PINECONE_API_KEY environment variable.")
        return None
    return Pinecone(api_key=key)
def connect_to_index(pc, index_name='pubmed-splade'):
    """Open a handle to an existing Pinecone index.

    Args:
        pc: An initialized Pinecone client.
        index_name: Name of the index to open. Defaults to ``'pubmed-splade'``,
            the index this app searches.

    Returns:
        The handle returned by ``pc.Index(index_name)``. Returns ``None``, after
        displaying an error in the app, when no index with that name exists.
    """
    # Check existence first so a missing index gives a readable message
    # rather than an error from the client.
    if index_name not in pc.list_indexes().names():
        st.error(f"Index '{index_name}' not found!")
        return None
    return pc.Index(index_name)
def encode_query(model, query_text):
    """Embed *query_text* with *model* and return the embedding as a plain list.

    ``model.encode`` must return an array-like object with a ``tolist()``
    method (presumably a NumPy array from SentenceTransformer). The list form
    is what the rest of the app scales and sends to Pinecone.
    """
    embedding = model.encode(query_text)
    return embedding.tolist()
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense/sparse query pair for Pinecone hybrid search.

    Dense values are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``. So ``alpha=1`` gives pure dense scoring and ``alpha=0``
    gives pure sparse scoring.

    Args:
        dense: Iterable of dense embedding values.
        sparse: Mapping with ``'indices'`` and ``'values'`` sequences.
        alpha: Weight in ``[0, 1]`` given to the dense vector.

    Returns:
        ``(scaled_dense, scaled_sparse)``: a new list and a new dict. The
        sparse ``'indices'`` object is reused as-is, not copied.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")
    sparse_weight = 1 - alpha
    scaled_sparse = {
        'indices': sparse['indices'],
        'values': [sparse_weight * x for x in sparse['values']],
    }
    scaled_dense = [alpha * x for x in dense]
    return scaled_dense, scaled_sparse
# Build the Pinecone client; initialize_pinecone() returns None when the API
# key is missing.
pc = initialize_pinecone()

# Only attempt the index connection when a client exists. Otherwise the
# global `index` keeps its initial None value.
if pc:
    index = connect_to_index(pc)
# Hugging Face id of the SPLADE sparse encoder. Its tokenizer is loaded from
# the same checkpoint.
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'


@st.cache_resource
def _load_query_encoders(splade_id):
    """Load the dense encoder, the SPLADE model and the SPLADE tokenizer.

    Streamlit re-executes this whole script on every widget interaction:
    typing a query, moving the slider, or pressing the button. Loading the
    models at module level therefore reloaded all three on every rerun.
    ``st.cache_resource`` keeps one shared copy per ``splade_id`` for the
    lifetime of the server process.

    Returns:
        ``(dense_model, splade_model, splade_tokenizer)``.
    """
    dense_model = SentenceTransformer('msmarco-bert-base-dot-v5')
    splade_model = Splade(splade_id, agg='max')
    splade_model.eval()  # evaluation mode: the app only encodes queries
    splade_tokenizer = AutoTokenizer.from_pretrained(splade_id)
    return dense_model, splade_model, splade_tokenizer


# Same module-level names as before; the search code below relies on them.
model, sparse_model, tokenizer = _load_query_encoders(sparse_model_id)
# Search controls: the free-text query and the dense/sparse balance.
query_text = st.text_input(
    "Enter a Query to Search",
    value="Can clinicians use the PHQ-9 to assess depression?",
)

# alpha = 1.0 weights only the dense vector, 0.0 only the sparse one
# (see hybrid_scale).
alpha = st.slider(
    "Set Alpha (for dense and sparse vector balancing)",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
)
def _encode_sparse(text):
    """Encode *text* with the SPLADE model into a Pinecone sparse-vector dict.

    Returns:
        ``{"indices": [...], "values": [...]}`` holding the positions and
        weights of the non-zero entries of the SPLADE representation. Both
        are always lists.
    """
    tokens = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        # Assumes 'd_rep' is (1, vocab_size) for a single query, so squeeze()
        # leaves a 1-D vector. TODO: confirm against the splade package.
        rep = sparse_model(d_kwargs=tokens.to('cpu'))['d_rep'].squeeze()
    # nonzero(as_tuple=True)[0] is always a 1-D index tensor. The previous
    # `.nonzero().squeeze()` collapsed to a 0-d tensor when exactly one term
    # was active, so `.tolist()` returned a bare int instead of a list.
    active = rep.nonzero(as_tuple=True)[0]
    return {
        "indices": active.cpu().tolist(),
        "values": rep[active].cpu().tolist(),
    }


# Run the hybrid search when the button is pressed.
if st.button("Search Query"):
    # Treat a whitespace-only query the same as an empty one.
    query = query_text.strip() if query_text else ""
    # `index` is None when the Pinecone client or index could not be set up.
    if query and index is not None:
        dense_vector = encode_query(model, query)
        sparse_dict = _encode_sparse(query)

        # Apply the user's dense/sparse weighting before querying.
        hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)

        results = index.query(
            vector=hdense,
            sparse_vector=hsparse,
            top_k=3,
            include_metadata=True,
        )

        st.write("### Search Results:")
        if not results.matches:
            st.info("No matching documents found.")
        for match in results.matches:
            st.markdown(f"#### Score: **{match.score:.4f}**")
            # Metadata is presumably None for vectors stored without any;
            # fall back to an empty dict so .get() cannot fail.
            context = (match.metadata or {}).get('context', 'No context available.')
            # Markdown needs a space after '####' to render a heading. The
            # former "####Context:" string displayed the raw hash marks.
            st.markdown("#### Context:")
            st.write(context)
            st.write("---")
    else:
        st.error("Please enter a query and ensure the index is initialized.")