File size: 3,749 Bytes
e2b9039 3ecbeff e2b9039 9a6f924 dd92eea e2b9039 b21411e 9610c39 e2b9039 d792c52 b21411e 9610c39 b21411e e2b9039 3ecbeff 9610c39 b21411e 3ecbeff e2b9039 b21411e e2b9039 9a6f924 b21411e fff69e7 e2b9039 b21411e 9610c39 3ecbeff b0a56a3 9610c39 b21411e dd92eea b0a56a3 b21411e b0a56a3 9a6f924 b21411e 9a6f924 b21411e dd92eea 9a6f924 febccb7 9a6f924 a19ad68 b21411e 9a6f924 9610c39 b21411e e740c3d 9610c39 badd3d7 e740c3d 9610c39 b21411e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import streamlit as st
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
import torch
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer
# Module-level handle to the Pinecone index; remains None unless a
# connection is established further down the script.
index = None

# Page heading rendered at the top of the Streamlit app.
st.title("Medical Hybrid Search")
def initialize_pinecone():
    """Create a Pinecone client from the ``PINECONE_API_KEY`` environment variable.

    Returns:
        A ``Pinecone`` client when the variable is set to a non-empty value;
        otherwise shows a Streamlit error and returns ``None``.
    """
    key = os.getenv('PINECONE_API_KEY')
    if not key:
        # Missing or empty key: report it in the UI and let the caller skip index setup.
        st.error("Pinecone API key not found! Please set the PINECONE_API_KEY environment variable.")
        return None
    return Pinecone(api_key=key)
def connect_to_index(pc):
    """Open the hard-coded ``'pubmed-splade'`` index on the given Pinecone client.

    Args:
        pc: A Pinecone client (as returned by ``initialize_pinecone``).

    Returns:
        The handle returned by ``pc.Index`` when the index exists; otherwise
        shows a Streamlit error and returns ``None``.
    """
    target = 'pubmed-splade'
    existing_names = pc.list_indexes().names()
    if target not in existing_names:
        st.error(f"Index '{target}' not found!")
        return None
    return pc.Index(target)
def encode_query(model, query_text):
    """Embed ``query_text`` with ``model`` and return the embedding via ``tolist()``.

    Args:
        model: An encoder exposing ``encode`` (in this app a SentenceTransformer).
        query_text: The query string to embed.

    Returns:
        ``model.encode(query_text)`` converted with ``tolist()``.
    """
    embedding = model.encode(query_text)
    return embedding.tolist()
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense and a sparse query vector for Pinecone hybrid search.

    Dense values are multiplied by ``alpha`` and sparse values by
    ``1 - alpha``, so ``alpha=1`` zeroes the sparse part and ``alpha=0``
    zeroes the dense part.

    Args:
        dense: Sequence of floats (the dense embedding).
        sparse: Mapping with ``'indices'`` and ``'values'`` sequences.
        alpha: Weight in ``[0, 1]``.

    Returns:
        ``(hdense, hsparse)``: the scaled dense list and a new sparse dict.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or is NaN.
    """
    # The chained comparison also rejects NaN, which `alpha < 0 or alpha > 1` let through.
    if not 0 <= alpha <= 1:
        raise ValueError("Alpha must be between 0 and 1")
    hsparse = {
        # Copy so the returned dict does not alias the caller's list.
        'indices': list(sparse['indices']),
        'values': [v * (1 - alpha) for v in sparse['values']],
    }
    hdense = [v * alpha for v in dense]
    return hdense, hsparse
# Initialize Pinecone (None, with an error shown, when the API key is missing).
pc = initialize_pinecone()

# If Pinecone initialized successfully, connect directly to the 'pubmed-splade'
# index; otherwise `index` keeps its module-level value of None.
if pc:
    index = connect_to_index(pc)

# Hugging Face id shared by the SPLADE model and its tokenizer.
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'


@st.cache_resource
def _load_encoders(splade_id):
    """Load the dense encoder, SPLADE model and tokenizer once and reuse them.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, both transformer models would be reloaded on each rerun.

    Returns:
        A ``(dense_model, sparse_model, tokenizer)`` tuple.
    """
    dense_model = SentenceTransformer('msmarco-bert-base-dot-v5')
    splade = Splade(splade_id, agg='max')
    splade.eval()  # switch to evaluation mode for inference
    splade_tokenizer = AutoTokenizer.from_pretrained(splade_id)
    return dense_model, splade, splade_tokenizer


def _encode_sparse(splade, splade_tokenizer, text):
    """Encode ``text`` with SPLADE and return a Pinecone sparse-vector dict.

    Returns:
        ``{"indices": [...], "values": [...]}`` holding the non-zero entries
        of the squeezed ``'d_rep'`` output.
    """
    inputs = splade_tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        rep = splade(d_kwargs=inputs.to('cpu'))['d_rep'].squeeze()
    # nonzero() returns shape (k, 1). view(-1) keeps it 1-D even when k == 1;
    # squeeze() would collapse that case to a 0-d tensor, and tolist() would
    # then produce a bare int/float instead of the lists Pinecone expects.
    nonzero_idx = rep.nonzero().view(-1)
    return {
        "indices": nonzero_idx.cpu().tolist(),
        "values": rep[nonzero_idx].cpu().tolist(),
    }


# Models for query encoding (dense + sparse), loaded once per server process.
model, sparse_model, tokenizer = _load_encoders(sparse_model_id)

# Query input
query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")
# Alpha input: weight of the dense vector; the sparse vector is weighted by (1 - alpha).
alpha = st.slider("Set Alpha (for dense and sparse vector balancing)", 0.0, 1.0, 0.5)

# Button to encode the query and search the Pinecone index
if st.button("Search Query"):
    # Compare with None explicitly instead of relying on the index object's truthiness.
    if query_text and index is not None:
        dense_vector = encode_query(model, query_text)
        sparse_dict = _encode_sparse(sparse_model, tokenizer, query_text)
        hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)
        # NOTE(review): the query arguments were lost from the original listing;
        # top_k=3 is reconstructed — confirm the intended number of results.
        results = index.query(
            vector=hdense,
            sparse_vector=hsparse,
            top_k=3,
            include_metadata=True,
        )
        if not results.matches:
            st.info("No matches found.")
        for match in results.matches:
            st.markdown(f"Score: {match.score:.4f}")
            # Guard against a match whose metadata object is missing.
            metadata = match.metadata or {}
            st.write(f"Answer: {metadata.get('context', 'No context available.')}")
    else:
        st.error("Please enter a query and ensure the index is initialized.")