import os

import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from splade.models.transformer_rep import Splade
import pinecone

# Initialize Pinecone connection
api_key = os.getenv('PINECONE_API_KEY')
pinecone.init(api_key=api_key, environment='us-east1-gcp')
index_name = 'pubmed-splade'

# Connect to the Pinecone index. The index listing is fetched once and reused
# so the membership test does not trigger a second remote call.
available_indexes = pinecone.list_indexes()
if available_indexes and index_name in available_indexes:
    index = pinecone.Index(index_name)
else:
    st.error("Pinecone index not found! Ensure the correct Pinecone index is being used.")
    # Halt this script run: without this, `index` stays undefined and the
    # query below would fail with a NameError instead of the message above.
    st.stop()

# Initialize Dense and Sparse models
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dense model (Sentence-BERT)
dense_model = SentenceTransformer('msmarco-bert-base-dot-v5', device=device)

# Sparse model (SPLADE)
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
sparse_model = Splade(sparse_model_id, agg='max').to(device)
sparse_model.eval()

# Tokenizer for sparse model
tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)


def encode(text: str):
    """Encode a query string into a dense vector and a sparse vector.

    Args:
        text: The query text to encode.

    Returns:
        A tuple ``(dense_vec, sparse_dict)`` where ``dense_vec`` is the dense
        embedding as a list and ``sparse_dict`` is a dict with the keys
        ``"indices"`` and ``"values"`` holding the positions and values of
        the non-zero entries of the sparse representation (both lists).
    """
    # Dense vector
    dense_vec = dense_model.encode(text).tolist()

    # Sparse vector
    input_ids = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        sparse_vec = sparse_model(d_kwargs=input_ids.to(device))['d_rep'].squeeze()

    # Extract non-zero values and indices for the sparse vector.
    # Only the trailing dimension of the (k, 1) nonzero() result is squeezed:
    # a bare squeeze() would turn a single hit into a 0-d tensor, making
    # `indices` an int and `values` a float instead of lists.
    nonzero_idx = sparse_vec.nonzero().squeeze(dim=-1)
    indices = nonzero_idx.cpu().tolist()
    values = sparse_vec[nonzero_idx].cpu().tolist()

    sparse_dict = {"indices": indices, "values": values}
    return dense_vec, sparse_dict


def hybrid_scale(dense, sparse, alpha: float):
    """Weight dense and sparse vectors for hybrid search.

    Dense values are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``, so ``alpha=1`` keeps only the dense contribution and
    ``alpha=0`` keeps only the sparse contribution.

    Args:
        dense: Iterable of dense vector values.
        sparse: Dict with ``'indices'`` and ``'values'`` entries.
        alpha: Weight in the closed interval [0, 1].

    Returns:
        A tuple ``(hdense, hsparse)`` with the scaled dense list and a new
        sparse dict carrying the same indices and the scaled values.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha < 0 or alpha > 1:
        raise ValueError("Alpha must be between 0 and 1")

    hsparse = {
        'indices': sparse['indices'],
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse


# Streamlit UI
st.title("PubMed Search Application")
query = st.text_input("Enter your query:", "")

# Slider to control sparse-dense scaling
alpha = st.slider("Hybrid Search Weight (Dense vs Sparse)", 0.0, 1.0, 0.5)

if query:
    # Encode the query
    dense_vec, sparse_vec = encode(query)

    # Scale vectors based on slider value
    hdense, hsparse = hybrid_scale(dense_vec, sparse_vec, alpha)

    # Query Pinecone index
    response = index.query(
        vector=hdense,
        sparse_vector=hsparse,
        top_k=3,
        include_metadata=True,
    )

    # Display results
    st.write(f"Top results for query: **{query}**")
    for match in response['matches']:
        st.write(f"**Score**: {match['score']}")
        st.write(f"**Context**: {match['metadata']['context']}")
        st.write("---")