secilozksen's picture
base model discarded
2d178ec
raw
history blame
11.3 kB
import copy
import streamlit as st
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from sentence_transformers.cross_encoder import CrossEncoder
from st_aggrid import GridOptionsBuilder, AgGrid
import pickle
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
import base64
import io
# Use Streamlit's wide page layout for the two-column demo page.
# NOTE(review): Streamlit expects set_page_config before any other st.* call,
# so keep this as the first Streamlit statement executed by the script.
st.set_page_config(layout="wide")
# Pipe-separated CSV of the original PolicyQA data; currently only referenced
# by the commented-out lines in load_dataframes.
DATAFRAME_FILE_ORIGINAL = 'policyQA_original.csv'
# Pipe-separated CSV whose rows populate the "Our Questions" grid.
DATAFRAME_FILE_BSBS = 'basecamp.csv'
# Select-box label -> integer method id dispatched by search_pipeline.
# Id 3 (base Contriever DPR) is disabled here and in search_pipeline.
selectbox_selections = {
    'Retrieve - Rerank (with fine-tuned cross-encoder)': 1,
    'Dense Passage Retrieval':2,
    # 'Base Dense Passage Retrieval': 3,
    'Retrieve - Rerank':4
}
# Select-box label -> pipeline diagram (PNG file) rendered above the question form.
# The 'Base Dense Passage Retrieval' entry is unused while that option is disabled.
imagebox_selections = {
    'Retrieve - Rerank (with fine-tuned cross-encoder)': 'Retrieve-rerank-trained-cross-encoder.png',
    'Dense Passage Retrieval': 'DPR_pipeline.png',
    'Base Dense Passage Retrieval': 'base-dpr.png',
    'Retrieve - Rerank': 'retrieve-rerank.png'
}
def retrieve_rerank(question, top_k=20):
    """Retrieve candidates with the bi-encoder, then rerank them with the base cross-encoder.

    Candidates come from ``retrieve`` (bi-encoder semantic search over
    ``context_embeddings``, up to 100 hits). Each candidate passage is then
    paired with the question and scored by ``cross_encoder``.

    Args:
        question: Free-text query.
        top_k: Number of reranked passages to return. Defaults to 20, the
            number this pipeline has always returned.

    Returns:
        A ``(contexts, scores)`` tuple of parallel lists, best first. Both
        lists are empty when retrieval finds nothing, so callers can always
        unpack two values.
    """
    hits = retrieve(question)
    if not hits:
        # Callers unpack two values; a bare [] would raise ValueError there.
        return [], []

    # Rerank: score every (question, passage) pair with the cross-encoder.
    cross_inp = [[question, contexes[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:top_k]
    top_contexes = [contexes[hit['corpus_id']] for hit in ranked]
    top_scores = [hit['cross-score'] for hit in ranked]
    return top_contexes, top_scores
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that reloads torch storage payloads onto the CPU.

    When the pickle references ``torch.storage._load_from_bytes``, that
    reconstructor is replaced by one that calls ``torch.load`` with
    ``map_location='cpu'``. Every other global resolves normally.

    SECURITY: like any unpickler, this can execute arbitrary code from the
    input stream. Only feed it trusted files.
    """

    def find_class(self, module, name):
        if (module, name) != ('torch.storage', '_load_from_bytes'):
            return super().find_class(module, name)

        def _load_on_cpu(raw_bytes):
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return _load_on_cpu
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_paragraphs(path):
    """Load a pickled embedding cache from *path*.

    The pickle is expected to be a mapping with ``'contexes'`` (passages) and
    ``'embeddings'`` keys. Torch storages inside it are mapped to the CPU by
    ``CPU_Unpickler``.

    Returns:
        ``(embeddings, contexes)``. Note that embeddings come first.

    SECURITY: unpickling runs arbitrary code; only load files shipped with
    the app.
    """
    with open(path, "rb") as pickled_file:
        payload = CPU_Unpickler(pickled_file).load()
    passages = payload['contexes']
    vectors = payload['embeddings']
    return vectors, passages
@st.cache(show_spinner=False)
def load_dataframes():
    """Read the Basecamp question CSV, drop answer/ID columns and shuffle its rows.

    The file is pipe-separated, and its first column is used as the index. Rows are
    shuffled with ``sample(frac=1)`` (unseeded) and re-indexed from 0.
    Loading of the original PolicyQA CSV (DATAFRAME_FILE_ORIGINAL) is
    currently disabled.
    """
    questions = pd.read_csv(DATAFRAME_FILE_BSBS, index_col=0, sep='|')
    questions = questions.drop(columns=['context_id', 'answer', 'answer_start', 'answer_end'])
    shuffled = questions.sample(frac=1)
    return shuffled.reset_index(drop=True)
def dot_product(question_output, context_output):
    """Return the dot product of two embeddings as a 0-d tensor.

    A leading dimension of size 1 (e.g. a batch of one) is squeezed away from
    each input first. The remaining tensors must be 1-D, because ``torch.dot``
    requires that.
    """
    query_vec = question_output.squeeze(0)
    passage_vec = context_output.squeeze(0)
    return torch.dot(query_vec, passage_vec)
def base_dpr_pipeline(question, top_k=5):
    """Rank passages with the off-the-shelf Contriever encoder (dense retrieval only).

    NOTE(review): this pipeline is currently disabled. Its select-box entry,
    its dispatch in search_pipeline and the loading of
    ``base_dpr_context_embeddings`` / ``base_contexes`` are all commented
    out, so calling it as-is raises NameError.

    The question is tokenized with ``question_tokenizer`` and encoded by
    ``base_dpr_context_encoder``. Both are loaded from
    'facebook/contriever-msmarco' in load_models. The encoding is then
    mean-pooled over non-padding tokens and scored against every passage
    embedding by dot product.

    Args:
        question: Free-text query.
        top_k: Number of passages to return. Defaults to 5, as before.

    Returns:
        ``(contexts, scores)`` parallel lists, best first. Scores are 0-d CPU
        tensors.
    """
    tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt")
    # Pure inference: disable autograd so no graph is built for the query
    # encoding or for any of the per-passage dot products.
    with torch.no_grad():
        model_output = base_dpr_context_encoder(**tokenized_question)
        question_embedding = mean_pooling(model_output[0], tokenized_question['attention_mask'])
        scores = [
            dot_product(question_embedding, context_embedding).detach().cpu()
            for context_embedding in base_dpr_context_embeddings
        ]
    # sorted() is stable, so passages with equal scores keep corpus order.
    ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:top_k]
    return [base_contexes[idx] for idx in ranking], [scores[idx] for idx in ranking]
def search_pipeline(question, search_method):
    """Dispatch *question* to the retrieval pipeline selected by *search_method*.

    Args:
        question: Free-text query.
        search_method: Integer id from ``selectbox_selections``:
            1 = bi-encoder retrieve + fine-tuned cross-encoder rerank,
            2 = fine-tuned DPR only,
            4 = bi-encoder retrieve + base cross-encoder rerank.
            Id 3 (base DPR) is disabled.

    Returns:
        The selected pipeline's ``(contexts, scores)`` tuple.

    Raises:
        ValueError: for an unknown or disabled method id. This replaces
            silently returning None, which only failed later with an opaque
            unpacking error in the caller.
    """
    if search_method == 1:
        return retrieve_rerank_with_trained_cross_encoder(question)
    if search_method == 2:
        return custom_dpr_pipeline(question)
    # Method 3 (base_dpr_pipeline) stays disabled: its embeddings are not
    # loaded at startup.
    if search_method == 4:
        return retrieve_rerank(question)
    raise ValueError(f"Unknown or disabled search method: {search_method!r}")
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over the positions where *mask* is non-zero.

    *mask* gets a trailing singleton axis so it broadcasts over the embedding
    dimension. Masked-out positions are zeroed, dimension 1 is summed, and the
    sum is divided by the per-row count of unmasked positions.
    """
    keep = mask.unsqueeze(-1).bool()
    summed = token_embeddings.masked_fill(~keep, 0.).sum(dim=1)
    counts = mask.sum(dim=1).unsqueeze(-1)
    return summed / counts
def custom_dpr_pipeline(question, top_k=5):
    """Rank passages with the fine-tuned DPR question encoder (dense retrieval only, no reranking).

    The question is tokenized with ``question_tokenizer``, which load_models
    loads from 'facebook/contriever-msmarco'. It is then encoded by
    ``dpr_trained.model.question_model``, mean-pooled over non-padding tokens,
    and scored against every precomputed passage embedding in
    ``dpr_context_embeddings`` by dot product.

    Args:
        question: Free-text query.
        top_k: Number of passages to return. Defaults to 5, as before.

    Returns:
        ``(contexts, scores)`` parallel lists, best first. Scores are 0-d CPU
        tensors.
    """
    tokenized_question = question_tokenizer(question, padding=True, truncation=True, return_tensors="pt")
    # Pure inference: disable autograd so no graph is built for the query
    # encoding or for any of the per-passage dot products.
    with torch.no_grad():
        model_output = dpr_trained.model.question_model(**tokenized_question)
        question_embedding = mean_pooling(model_output[0], tokenized_question['attention_mask'])
        scores = [
            dot_product(question_embedding, context_embedding).detach().cpu()
            for context_embedding in dpr_context_embeddings
        ]
    # sorted() is stable, so passages with equal scores keep corpus order.
    ranking = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)[:top_k]
    return [dpr_contexes[idx] for idx in ranking], [scores[idx] for idx in ranking]
def retrieve(question):
    """Run bi-encoder semantic search for *question* over ``context_embeddings``.

    Returns:
        The hit list for this single query (up to 100 dicts; callers read
        each hit's ``'corpus_id'``), or ``[]`` if ``util.semantic_search``
        returned no result lists at all.
    """
    query_vector = bi_encoder.encode(question, convert_to_tensor=True)
    per_query_hits = util.semantic_search(query_vector, context_embeddings, top_k=100)
    return per_query_hits[0] if per_query_hits else []
def retrieve_rerank_with_trained_cross_encoder(question, top_k=5):
    """Retrieve candidates with the bi-encoder and rerank them with the fine-tuned cross-encoder.

    Candidates come from ``retrieve`` (up to 100 bi-encoder hits). Each
    candidate passage is paired with the question and scored by
    ``trained_cross_encoder``.

    Args:
        question: Free-text query.
        top_k: Number of reranked passages to return. Defaults to 5, as before.

    Returns:
        A ``(contexts, scores)`` tuple of parallel lists, best first. Both
        lists are empty when retrieval finds nothing.
    """
    hits = retrieve(question)
    if not hits:
        # Skip the cross-encoder entirely on an empty candidate set and still
        # return the two-list shape the caller unpacks.
        return [], []

    cross_inp = [(question, contexes[hit['corpus_id']]) for hit in hits]
    cross_scores = trained_cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)[:top_k]
    top_contexes = [contexes[hit['corpus_id']] for hit in ranked]
    top_scores = [hit['cross-score'] for hit in ranked]
    return top_contexes, top_scores
def interactive_table(dataframe):
    """Render *dataframe* as a paginated AgGrid table and return AgGrid's response.

    Rows are selectable one at a time ('single' selection mode). The caller
    reads ``['selected_rows']`` from the returned response.
    """
    builder = GridOptionsBuilder.from_dataframe(dataframe)
    builder.configure_pagination(paginationAutoPageSize=True)
    builder.configure_side_bar()
    # Single-row selection.
    # NOTE(review): groupSelectsChildren receives a non-empty string, which is
    # truthy; presumably True was intended. Kept as-is.
    builder.configure_selection(
        'single',
        rowMultiSelectWithClick=True,
        groupSelectsChildren="Group checkbox select children",
    )
    grid_settings = dict(
        gridOptions=builder.build(),
        data_return_mode='AS_INPUT',
        update_mode='SELECTION_CHANGED',
        enable_enterprise_modules=False,
        fit_columns_on_grid_load=False,
        theme='streamlit',
        height=350,
        width='100%',
        reload_data=False,
    )
    return AgGrid(dataframe, **grid_settings)
def img_to_bytes(img_path):
    """Return the file at *img_path* base64-encoded as a str (for data: URIs)."""
    with open(img_path, "rb") as image_file:
        raw = image_file.read()
    return base64.b64encode(raw).decode()
def qa_main_widgetsv2():
    """Render the QA demo page.

    The page has a method picker with its pipeline diagram, a question form
    with search results (left column) and the sample-question grid (right
    column).

    ``st.session_state.form_submit`` stays True after a successful search, so
    results remain rendered on later reruns of this function, e.g. when the
    user clicks a grid row. It is reset when a search finds nothing.
    """
    st.title("Question Answering Demo")
    st.markdown("""---""")
    option = st.selectbox("Select a search method:", list(selectbox_selections.keys()))
    # Inline the diagram as a base64 data URI. The original tag had a stray
    # comma between the width and height attributes, which is invalid HTML.
    header_html = "<center> <img src='data:image/png;base64,{}' class='img-fluid' width='60%' height='40%'> </center>".format(
        img_to_bytes(imagebox_selections[option])
    )
    st.markdown(header_html, unsafe_allow_html=True)
    st.markdown("""---""")
    col1, col3 = st.columns([1, 1])

    with col1:
        form = st.form(key='first_form')
        question = form.text_area("What is your question?:", height=200)
        submit = form.form_submit_button('Submit')
        if "form_submit" not in st.session_state:
            st.session_state.form_submit = False
        if submit:
            st.session_state.form_submit = True
        # A whitespace-only question is treated like an empty one.
        if st.session_state.form_submit and question.strip():
            with st.spinner(text='Related context search in progress..'):
                results = search_pipeline(question.strip(), selectbox_selections[option])
            # Tolerate pipelines that signal "nothing found" with an empty or
            # None result instead of a pair of lists.
            top_5_contexes, top_5_scores = results if results else ([], [])
            if len(top_5_contexes) == 0:
                st.error("Related context not found!")
                st.session_state.form_submit = False
            else:
                for i, context in enumerate(top_5_contexes):
                    st.markdown(f"## Related Context - {i + 1} (score: {top_5_scores[i]:.2f})")
                    st.markdown(context)
                    st.markdown("""---""")

    with col3:
        st.markdown("## Our Questions")
        grid_response = interactive_table(dataframe_bsbs)
        selected_rows = grid_response['selected_rows']
        # NOTE(review): depending on the st_aggrid version, selected_rows may
        # be a list of dicts, a DataFrame, or None when nothing is selected.
        if selected_rows is not None and len(selected_rows) > 0:
            selection = selected_rows.iloc[0] if hasattr(selected_rows, 'iloc') else selected_rows[0]
            st.markdown("### Context:")
            st.write(selection['context'])
            st.markdown("### Question:")
            st.write(selection['question'])
@st.cache(show_spinner=False, allow_output_mutation = True)
def load_models(dpr_model_path, auth_token, cross_encoder_model_path):
    """Load every model the search pipelines use (memoized by st.cache).

    Args:
        dpr_model_path: Hub id/path of the fine-tuned DPR model. It is loaded
            with ``trust_remote_code=True`` and *auth_token*.
        auth_token: Token passed as ``use_auth_token`` for the DPR model.
        cross_encoder_model_path: Path/id of the fine-tuned cross-encoder.

    Returns:
        ``(dpr_trained, bi_encoder, cross_encoder, trained_cross_encoder,
        question_tokenizer, base_dpr_context_encoder)``, in that order.
    """
    contriever_id = 'facebook/contriever-msmarco'
    dpr_trained = AutoModel.from_pretrained(
        dpr_model_path,
        use_auth_token=auth_token,
        trust_remote_code=True,
    )
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    # Raise the bi-encoder's input truncation limit to 500 tokens.
    bi_encoder.max_seq_length = 500
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    trained_cross_encoder = CrossEncoder(cross_encoder_model_path)
    base_dpr_context_encoder = AutoModel.from_pretrained(contriever_id)
    question_tokenizer = AutoTokenizer.from_pretrained(contriever_id)
    return (
        dpr_trained,
        bi_encoder,
        cross_encoder,
        trained_cross_encoder,
        question_tokenizer,
        base_dpr_context_encoder,
    )
# Load the cached corpora, sample questions and models, then render the page.
# Streamlit re-executes this whole script on every interaction; the st.cache
# decorators on the loaders keep these calls cheap after the first run.
context_embeddings, contexes = load_paragraphs('st-context-embeddings.pkl')
dpr_context_embeddings, dpr_contexes = load_paragraphs('basecamp-dpr-contriever-embeddings.pkl')
#base_dpr_context_embeddings, base_contexes = load_paragraphs('basecamp-base-dpr-contriever-embeddings.pkl')
dataframe_bsbs = load_dataframes()
# Use the cached model objects directly. The previous copy.deepcopy duplicated
# every model on every rerun (i.e. on every widget interaction), which defeated
# the cache. The pipelines only run inference and do not mutate the models.
(
    dpr_trained,
    bi_encoder,
    cross_encoder,
    trained_cross_encoder,
    question_tokenizer,
    base_dpr_context_encoder,
) = load_models(
    st.secrets["DPR_MODEL_PATH"],
    st.secrets["AUTH_TOKEN"],
    st.secrets["CROSS_ENCODER_MODEL_PATH"],
)
qa_main_widgetsv2()
# Manual smoke test; method 3 (base DPR) is disabled, so pick 1, 2 or 4.
#if __name__ == '__main__':
#    top_5_contexes, top_5_scores = search_pipeline('What contributions does 37Signals make to open-source projects?', 3)