Binaryy commited on
Commit
e700a38
1 Parent(s): 50c3b73

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pinecone
3
+ from sentence_transformers import SentenceTransformer
4
+ import torch
5
+ from splade.models.transformer_rep import Splade
6
+ from transformers import AutoTokenizer
7
+ from datasets import load_dataset
8
+
9
+ pinecone.init(
10
+ api_key='884344f6-d820-4bc8-9edf-4157373df452',
11
+ environment='gcp-starter'
12
+ )
13
+ index = pinecone.Index('pubmed-splade')
14
+
15
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+ # check device being run on
17
+ if device != 'cuda':
18
+ print("==========\n"+
19
+ "WARNING: You are not running on GPU so this may be slow.\n"+
20
+ "\n==========")
21
+
22
+ dense_model = SentenceTransformer(
23
+ 'msmarco-bert-base-dot-v5',
24
+ device=device
25
+ )
26
+
27
+ sparse_model_id = 'naver/splade-cocondenser-ensembledistil'
28
+ sparse_model = Splade(sparse_model_id, agg='max')
29
+ sparse_model.to(device) # move to GPU if possible
30
+ sparse_model.eval()
31
+
32
+ tokenizer = AutoTokenizer.from_pretrained(sparse_model_id)
33
+ data = load_dataset('Binaryy/cream_listings', split='train')
34
+ df = data.to_pandas()
35
+
36
+ def encode(text: str):
37
+ # create dense vec
38
+ dense_vec = dense_model.encode(text).tolist()
39
+ # create sparse vec
40
+ input_ids = tokenizer(text, return_tensors='pt')
41
+ with torch.no_grad():
42
+ sparse_vec = sparse_model(
43
+ d_kwargs=input_ids.to(device)
44
+ )['d_rep'].squeeze()
45
+ # convert to dictionary format
46
+ indices = sparse_vec.nonzero().squeeze().cpu().tolist()
47
+ values = sparse_vec[indices].cpu().tolist()
48
+ sparse_dict = {"indices": indices, "values": values}
49
+ # return vecs
50
+ return dense_vec, sparse_dict
51
+
52
+ def search(query):
53
+ dense, sparse = encode(query)
54
+ # query
55
+ xc = index.query(
56
+ vector=dense,
57
+ sparse_vector=sparse,
58
+ top_k=5, # how many results to return
59
+ include_metadata=True
60
+ )
61
+ match_ids = [match['id'].split('-')[0] for match in xc['matches']]
62
+ # Query the existing DataFrame based on 'id'
63
+ filtered_df = df[df['_id'].isin(match_ids)]
64
+ attributes_to_extract = ['_id', 'postedBy.accountName', 'images', 'title', 'location', 'price']
65
+ extracted_data = filtered_df[attributes_to_extract]
66
+ result_json = extracted_data.to_json(orient='records')
67
+ return result_json
68
+
69
+ # Create a Gradio interface
70
+ iface = gr.Interface(
71
+ fn=search,
72
+ inputs="text",
73
+ outputs="json",
74
+ title="Semantic Search Prototype",
75
+ description="Enter your query to perform a semantic search.",
76
+ )
77
+
78
+ # Launch the Gradio interface
79
+ iface.launch(share=True)