File size: 3,773 Bytes
e2b9039
 
3ecbeff
e2b9039
9a6f924
dd92eea
 
e2b9039
b21411e
9610c39
e2b9039
d792c52
 
 
b21411e
 
 
 
9610c39
 
b21411e
 
 
e2b9039
3ecbeff
 
 
 
9610c39
 
b21411e
3ecbeff
 
e2b9039
b21411e
 
 
e2b9039
9a6f924
 
 
 
 
 
 
 
 
 
 
b21411e
fff69e7
e2b9039
b21411e
 
9610c39
3ecbeff
b0a56a3
9610c39
b21411e
dd92eea
 
 
 
 
 
b0a56a3
b21411e
 
b0a56a3
9a6f924
 
 
b21411e
 
 
9a6f924
b21411e
dd92eea
9a6f924
febccb7
9a6f924
 
 
 
 
 
 
 
 
a19ad68
b21411e
9a6f924
 
9610c39
b21411e
 
 
a19ad68
9610c39
 
dd92eea
9610c39
b21411e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import streamlit as st
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer
import torch
from splade.models.transformer_rep import Splade
from transformers import AutoTokenizer

# Page title shown at the top of the Streamlit app.
st.title("Medical Hybrid Search")

# Module-level handle for the Pinecone index. It starts out as None and is
# only reassigned further down (with the result of connect_to_index, which may
# itself be None) when a Pinecone client could be created.
index = None

# Function to initialize Pinecone
def initialize_pinecone():
    """Create a Pinecone client from the ``PINECONE_API_KEY`` environment variable.

    Returns:
        A ``Pinecone`` client built with the key when the variable holds a
        non-empty value; otherwise ``None``, after the problem has been
        reported in the app via ``st.error``.
    """
    pinecone_key = os.environ.get('PINECONE_API_KEY')
    if not pinecone_key:
        # An unset variable and an empty one are treated the same way.
        st.error("Pinecone API key not found! Please set the PINECONE_API_KEY environment variable.")
        return None
    return Pinecone(api_key=pinecone_key)

# Function to connect to the 'pubmed-splade' index
def connect_to_index(pc):
    """Return a handle to the ``pubmed-splade`` index, or ``None`` if it is not listed.

    Args:
        pc: Client object exposing ``list_indexes().names()`` and
            ``Index(name)`` (the value returned by ``initialize_pinecone``).

    Returns:
        Whatever ``pc.Index('pubmed-splade')`` returns when that name appears
        in the client's index listing; otherwise ``None``, after an error
        message has been shown with ``st.error``.
    """
    target = 'pubmed-splade'  # Hardcoded index name
    available = pc.list_indexes().names()
    if target not in available:
        st.error(f"Index '{target}' not found!")
        return None
    return pc.Index(target)

# Function to encode query using sentence transformers model
def encode_query(model, query_text):
    """Encode *query_text* with *model* and return the result as a plain list.

    Args:
        model: Object whose ``encode(text)`` returns something with a
            ``tolist()`` method (the script passes a ``SentenceTransformer``).
        query_text: The text to encode.

    Returns:
        The value of ``model.encode(query_text).tolist()``.
    """
    embedding = model.encode(query_text)
    return embedding.tolist()

# Function to create hybrid scaled vectors
def hybrid_scale(dense, sparse, alpha):
    """Weight a dense/sparse vector pair by a convex combination factor.

    Every dense component is multiplied by ``alpha`` and every sparse value by
    ``1 - alpha``, so ``alpha=1`` keeps only the dense signal and ``alpha=0``
    keeps only the sparse one.

    Args:
        dense: Iterable of numeric dense-vector components.
        sparse: Mapping with ``'indices'`` and ``'values'`` entries.
        alpha: Weight of the dense vector; must lie in ``[0, 1]``.

    Returns:
        Tuple ``(hdense, hsparse)``: the scaled dense list and a new dict that
        carries the original ``'indices'`` object and the scaled ``'values'``.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("Alpha must be between 0 and 1")

    sparse_weight = 1 - alpha
    scaled_sparse = dict(
        indices=sparse['indices'],
        values=[value * sparse_weight for value in sparse['values']],
    )
    scaled_dense = [component * alpha for component in dense]
    return scaled_dense, scaled_sparse

# Streamlit re-runs this script from the top on every widget interaction, so
# the encoders are built inside a cached factory rather than inline; otherwise
# every slider move or button click would load all three models again.
@st.cache_resource
def _load_encoders():
    """Build the dense encoder, the SPLADE sparse encoder and its tokenizer.

    Returns:
        Tuple ``(dense_model, sparse_model, tokenizer)``. The sparse model is
        switched to evaluation mode before it is returned.
    """
    sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
    dense_encoder = SentenceTransformer('msmarco-bert-base-dot-v5')
    sparse_encoder = Splade(sparse_model_id, agg='max')
    sparse_encoder.eval()  # Set the model to evaluation mode
    sparse_tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
    return dense_encoder, sparse_encoder, sparse_tokenizer


def _to_pinecone_sparse(sparse_vector):
    """Convert a tensor of term weights into ``{'indices', 'values'}`` lists.

    Assumes ``sparse_vector`` is 1-D (the caller squeezes the model output
    first) -- TODO confirm if the model output shape ever changes.
    """
    # flatten() keeps the index tensor 1-D even when exactly one weight is
    # non-zero. A bare squeeze() would collapse that case to a 0-d tensor, and
    # tolist() would then yield scalars that hybrid_scale cannot iterate.
    nonzero_idx = sparse_vector.nonzero().flatten()
    return {
        "indices": nonzero_idx.cpu().tolist(),
        "values": sparse_vector[nonzero_idx].cpu().tolist(),
    }


# Initialize Pinecone
pc = initialize_pinecone()

# If Pinecone initialized successfully, proceed with index management
if pc:
    # Connect directly to 'pubmed-splade' index (None if it does not exist)
    index = connect_to_index(pc)

    # Dense query encoder, sparse (SPLADE) encoder and its tokenizer
    model, sparse_model, tokenizer = _load_encoders()

    # Query input
    query_text = st.text_input("Enter a Query to Search", "Can clinicians use the PHQ-9 to assess depression?")

    # Alpha input
    alpha = st.slider("Set Alpha (for dense and sparse vector balancing)", 0.0, 1.0, 0.5)

    # Button to encode query and search the Pinecone index
    if st.button("Search Query"):
        # Test the index against None explicitly instead of relying on the
        # truthiness of a client object, and reject whitespace-only queries.
        if query_text.strip() and index is not None:
            # Encode query to get dense and sparse vectors
            dense_vector = encode_query(model, query_text)
            input_ids = tokenizer(query_text, return_tensors='pt')
            with torch.no_grad():
                sparse_vector = sparse_model(d_kwargs=input_ids.to('cpu'))['d_rep'].squeeze()

            # Prepare sparse vector format for Pinecone
            sparse_dict = _to_pinecone_sparse(sparse_vector)

            # Scale dense and sparse vectors
            hdense, hsparse = hybrid_scale(dense_vector, sparse_dict, alpha)

            # Search the index
            results = index.query(
                vector=hdense,
                sparse_vector=hsparse,
                top_k=3,
                include_metadata=True
            )

            st.write("### Search Results:")
            if not results.matches:
                st.info("No matching documents found.")
            for match in results.matches:
                # A match may come back without metadata; fall back to an
                # empty mapping so the default text is shown.
                metadata = match.metadata or {}
                st.markdown(f"#### Score: **{match.score:.4f}**")
                # "####Context:" (no space after the hashes) is not a Markdown
                # heading and was rendered literally; use a bold label instead.
                st.write(f"**Context:** {metadata.get('context', 'No context available.')}")
                st.write("---")
        else:
            st.error("Please enter a query and ensure the index is initialized.")