File size: 2,833 Bytes
e2b9039
 
 
 
 
 
 
 
 
fcee9ac
e2b9039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from splade.models.transformer_rep import Splade
import pinecone

# Initialize Pinecone connection
api_key = os.getenv('PINECONE_API_KEY')
pinecone.init(api_key=api_key, environment='us-east1-gcp')
index_name = 'pubmed-splade'

# Connect to the Pinecone index.
# list_indexes() is queried once; a falsy result is treated as "no indexes".
available_indexes = pinecone.list_indexes() or []
if index_name in available_indexes:
    index = pinecone.Index(index_name)
else:
    st.error("Pinecone index not found! Ensure the correct Pinecone index is being used.")
    # Halt this script run here: `index` was never bound, so letting execution
    # continue would end in a NameError at the index.query(...) call further down.
    st.stop()

# Initialize Dense and Sparse models
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Checkpoint id shared by the sparse model (SPLADE) and its tokenizer
sparse_model_id = 'naver/splade-cocondenser-ensembledistil'


@st.cache_resource
def _load_models(device: str, sparse_model_id: str):
    """Build the dense encoder, the sparse encoder and the sparse tokenizer.

    Streamlit re-executes the script on each user interaction; caching the
    loaded objects keeps the models from being re-instantiated on every rerun.

    Args:
        device: Device string the models are placed on ('cuda' or 'cpu').
        sparse_model_id: Checkpoint id used for both Splade and its tokenizer.

    Returns:
        Tuple ``(dense_model, sparse_model, tokenizer)``.
    """
    # Dense model (Sentence-BERT)
    dense = SentenceTransformer('msmarco-bert-base-dot-v5', device=device)

    # Sparse model (SPLADE), switched to eval mode for inference
    sparse = Splade(sparse_model_id, agg='max').to(device)
    sparse.eval()

    # Tokenizer for sparse model
    sparse_tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
    return dense, sparse, sparse_tokenizer


dense_model, sparse_model, tokenizer = _load_models(device, sparse_model_id)

# Function to encode query into dense and sparse vectors
def encode(text: str):
    """Encode ``text`` into a dense vector and a sparse index/value dict.

    Args:
        text: The query string to encode.

    Returns:
        Tuple ``(dense_vec, sparse_dict)``. ``dense_vec`` is the output of
        ``dense_model.encode(text)`` converted to a list. ``sparse_dict`` is
        ``{"indices": [...], "values": [...]}`` holding the positions and
        values of the non-zero entries of the sparse representation; both
        entries are always lists, including when there is only one non-zero
        entry or none.
    """
    # Dense vector
    dense_vec = dense_model.encode(text).tolist()

    # Sparse vector. truncation=True caps the input at the tokenizer's
    # configured maximum length so an over-long query is cut rather than
    # passed to the model at full length.
    input_ids = tokenizer(text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        sparse_vec = sparse_model(d_kwargs=input_ids.to(device))['d_rep'].squeeze()

    # Extract non-zero values and indices for sparse vector.
    # Only the trailing dimension of nonzero()'s output is squeezed: a bare
    # squeeze() would collapse a single hit into a 0-d tensor, and tolist()
    # would then return scalars instead of lists.
    # NOTE(review): assumes sparse_vec is 1-D here (single query) - confirm.
    nonzero_idx = sparse_vec.nonzero().squeeze(-1)
    indices = nonzero_idx.cpu().tolist()
    values = sparse_vec[nonzero_idx].cpu().tolist()

    sparse_dict = {"indices": indices, "values": values}
    return dense_vec, sparse_dict

# Function for hybrid search scaling
def hybrid_scale(dense, sparse, alpha: float):
    """Weight a dense/sparse vector pair for a hybrid query.

    Every dense component is multiplied by ``alpha`` and every sparse value by
    ``1 - alpha``; the sparse indices are passed through unchanged.

    Args:
        dense: Iterable of dense vector components.
        sparse: Mapping with an 'indices' entry and a 'values' entry.
        alpha: Weight in the range 0..1 applied to the dense side.

    Returns:
        Tuple ``(hdense, hsparse)``: the scaled dense list and a new dict with
        'indices' and 'values' keys.

    Raises:
        ValueError: If ``alpha`` is below 0 or above 1.
    """
    out_of_range = alpha < 0 or alpha > 1
    if out_of_range:
        raise ValueError("Alpha must be between 0 and 1")

    # Sparse side gets the complementary weight
    sparse_weight = 1 - alpha
    scaled_sparse = dict(
        indices=sparse['indices'],
        values=[entry * sparse_weight for entry in sparse['values']],
    )
    scaled_dense = [component * alpha for component in dense]

    return scaled_dense, scaled_sparse

# Streamlit UI
st.title("PubMed Search Application")
query = st.text_input("Enter your query:", "")

# Slider to control sparse-dense scaling.
# hybrid_scale multiplies the dense vector by alpha and the sparse values by
# (1 - alpha), so 1.0 means dense only and 0.0 means sparse only.
alpha = st.slider(
    "Hybrid Search Weight (Dense vs Sparse)",
    0.0,
    1.0,
    0.5,
    help="1.0 = dense only, 0.0 = sparse only",
)

# Treat whitespace-only input as no query so it does not trigger a search
query = query.strip()

if query:
    # Encode the query
    dense_vec, sparse_vec = encode(query)

    # Scale vectors based on slider value
    hdense, hsparse = hybrid_scale(dense_vec, sparse_vec, alpha)

    # Query Pinecone index
    response = index.query(vector=hdense, sparse_vector=hsparse, top_k=3, include_metadata=True)

    # Display results
    st.write(f"Top results for query: **{query}**")
    matches = response['matches']
    if not matches:
        # Give explicit feedback rather than rendering an empty result list
        st.info("No results found.")
    for match in matches:
        st.write(f"**Score**: {match['score']}")
        # A record stored without a 'context' field should not abort rendering
        # of the remaining matches.
        # NOTE(review): assumes match['metadata'] is a dict-like with .get - confirm.
        context = match['metadata'].get('context')
        if context is None:
            st.write("**Context**: _(not available)_")
        else:
            st.write(f"**Context**: {context}")
        st.write("---")