"""Streamlit front end for hybrid (dense + SPLADE sparse) search on Pinecone.

A free-text query is embedded twice: once with a SentenceTransformer (dense
vector) and once with a SPLADE model (sparse term-weight vector). The two are
blended with a user-chosen ``alpha`` and sent to the ``pubmed-splade``
Pinecone index; the top matches are rendered in the page.
"""

import os

import streamlit as st
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
import torch
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer

# Model identifiers used to encode queries.
DENSE_MODEL_ID = 'msmarco-bert-base-dot-v5'
SPARSE_MODEL_ID = 'naver/splade-cocondenser-ensembledistil'
# Name of the Pinecone index queried by this app.
INDEX_NAME = 'pubmed-splade'

# Title of the Streamlit App
st.title("Medical Hybrid Search")

# Initialize Pinecone globally; stays None if the client or index is unavailable.
index = None


def initialize_pinecone():
    """Create a Pinecone client from the ``PINECONE_API_KEY`` env variable.

    Returns:
        A ``Pinecone`` client, or ``None`` (after showing an error in the UI)
        when the environment variable is not set.
    """
    api_key = os.getenv('PINECONE_API_KEY')
    if api_key:
        return Pinecone(api_key=api_key)
    st.error("Pinecone API key not found! Please set the PINECONE_API_KEY environment variable.")
    return None


def connect_to_index(pc, index_name=INDEX_NAME):
    """Return a handle to ``index_name`` if it exists in the Pinecone project.

    Args:
        pc: An initialized ``Pinecone`` client.
        index_name: Index to connect to; defaults to ``'pubmed-splade'``.

    Returns:
        The index handle, or ``None`` (after showing an error in the UI) when
        the index is not listed by the client.
    """
    if index_name in pc.list_indexes().names():
        return pc.Index(index_name)
    st.error(f"Index '{index_name}' not found!")
    return None


def encode_query(model, query_text):
    """Encode ``query_text`` with a SentenceTransformer into a plain list of floats."""
    return model.encode(query_text).tolist()


def encode_sparse_query(sparse_model, tokenizer, query_text):
    """Encode ``query_text`` with SPLADE into Pinecone's sparse-vector format.

    Returns:
        ``{"indices": [...], "values": [...]}`` holding the vocabulary ids and
        weights of the non-zero entries. Both are always lists, even when
        exactly one term is non-zero.
    """
    tokens = tokenizer(query_text, return_tensors='pt')
    with torch.no_grad():
        # Flatten the (presumably batch-of-one) representation to 1-D so the
        # index arithmetic below does not depend on a leading batch dimension.
        rep = sparse_model(d_kwargs=tokens.to('cpu'))['d_rep'].reshape(-1)
    # nonzero() on a 1-D tensor returns shape (k, 1); flatten() (unlike
    # squeeze()) keeps a 1-D result when k == 1, so tolist() returns a list
    # rather than a bare int.
    nz = rep.nonzero().flatten()
    return {"indices": nz.cpu().tolist(), "values": rep[nz].cpu().tolist()}


def hybrid_scale(dense, sparse, alpha):
    """Weight a dense/sparse vector pair for convex-combination hybrid search.

    Dense values are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``; the inputs are not modified.

    Args:
        dense: Sequence of dense vector values.
        sparse: Mapping with ``'indices'`` and ``'values'`` lists.
        alpha: Blend weight in ``[0, 1]``; 1 is dense-only, 0 is sparse-only.

    Returns:
        ``(hdense, hsparse)`` scaled copies of the inputs.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` (NaN included).
    """
    # Chained comparison is False for NaN, so NaN is rejected as well.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse


@st.cache_resource
def load_models():
    """Load the dense model, SPLADE model and tokenizer once per server process.

    Streamlit re-executes this script on every widget interaction; caching
    avoids reloading the models on each rerun.
    """
    dense_model = SentenceTransformer(DENSE_MODEL_ID)
    splade_model = Splade(SPARSE_MODEL_ID, agg='max')
    splade_model.eval()  # Set the model to evaluation mode
    splade_tokenizer = AutoTokenizer.from_pretrained(SPARSE_MODEL_ID)
    return dense_model, splade_model, splade_tokenizer


# Initialize Pinecone
pc = initialize_pinecone()

# If Pinecone initialized successfully, connect to the index and load models.
if pc:
    index = connect_to_index(pc)
    model, sparse_model, tokenizer = load_models()

# Query input
query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")

# Alpha input
alpha = st.slider("Set Alpha (for dense and sparse vector balancing)", 0.0, 1.0, 0.5)

# Button to encode query and search the Pinecone index
if st.button("Search Query"):
    # `index` is only non-None when `pc` succeeded, which also guarantees the
    # models above were loaded.
    if query_text and query_text.strip() and index is not None:
        dense_vector = encode_query(model, query_text)
        sparse_dict = encode_sparse_query(sparse_model, tokenizer, query_text)

        # Scale dense and sparse vectors
        hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)

        # Search the index
        results = index.query(
            vector=hdense,
            sparse_vector=hsparse,
            top_k=3,
            include_metadata=True,
        )

        st.write("Results:")
        for match in results.matches:
            # Metadata can be absent on a match; fall back to an empty mapping.
            metadata = match.metadata or {}
            st.markdown(f"Score: {match.score:.4f}")
            st.write(f"Answer: {metadata.get('context', 'No context available.')}")
            st.write("---")
    else:
        st.error("Please enter a query and ensure the index is initialized.")