"""ReSRer demo: a Retriever-Summarizer-Reader question-answering Streamlit app."""
# Standard library
import os

# Third-party
import streamlit as st
import torch
from pymilvus import MilvusClient

# Local pipeline components: DPR question encoder, summarizer, extractive reader
from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader
# Page-level constants and header chrome.
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
# Default question shown in the text input on first load.
INITIAL = "What is the population of NYC"

st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
<h5>Ask short-answer question that can be find in Wikipedia data.</h5>
''', unsafe_allow_html=True)
st.markdown(
    'This demo searches through 21,000,000 Wikipedia passages in real-time under the hood.')
@st.cache_resource
def load_models():
    """Load and return the three pipeline models keyed by role.

    Returns:
        dict: 'encoder' -> DPR question encoder, 'summarizer' -> context
        summarizer, 'reader' -> extractive QA reader. Each value is whatever
        pair/object the corresponding `model` factory returns.

    Cached with st.cache_resource so Streamlit script reruns reuse the
    already-loaded models instead of reloading them on every interaction.
    """
    return {
        'encoder': get_dpr_encoder(),
        'summarizer': get_summarizer(),
        'reader': get_reader(),
    }
@st.cache_resource
def load_client():
    """Connect to the Milvus vector store holding the Wikipedia passages.

    Reads MILVUS_HOST and MILVUS_PW from the environment (raises KeyError if
    either is unset). Cached with st.cache_resource so reruns reuse one
    connection instead of reconnecting per interaction.

    Returns:
        MilvusClient: client bound to the 'psgs_w100' database.
    """
    return MilvusClient(
        user='resrer',
        password=os.environ['MILVUS_PW'],
        uri=f"http://{os.environ['MILVUS_HOST']}:19530",
        db_name='psgs_w100',
    )
# Initialize the vector-store client and the model bundle used by main().
client = load_client()
models = load_models()
# CSS tweak: center the Streamlit "running" status widget in the viewport
# and hide its stop button while a query is being processed.
styl = """
<style>
.StatusWidget-enter-done{
position: fixed;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
}
.StatusWidget-enter-done button{
display: none;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
question = st.text_input("Question", INITIAL)

# Example questions: clicking a button overrides the typed question for
# this script run only (Streamlit reruns top-to-bottom on each click).
_EXAMPLES = (
    "What is the capital of South Korea",
    "What is the most famous building in Paris",
    "Who is the actor of Harry Potter",
)
for _col, _example in zip(st.columns(3), _EXAMPLES):
    if _col.button(_example):
        question = _example
def main(question: str) -> None:
    """Answer `question` via retrieve -> summarize -> read, then render results.

    Results are memoized in st.session_state keyed by the raw question text,
    so asking the same question again skips the whole model pipeline.

    Args:
        question: free-text short-answer question from the UI.
    """
    if question in st.session_state:
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")
        # Embedding: encode the question into a dense DPR vector.
        question_vectors = encode_dpr_question(
            models['encoder'][0], models['encoder'][1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
        # Retriever: top-10 nearest passages from the Milvus collection.
        results = client.search(collection_name='dpr_nq', data=[
            query_vector], limit=10, output_fields=['title', 'text'])
        texts = [hit['entity']['text'] for hit in results[0]]
        ctx = '\n'.join(texts)
        # Summarizer: compress the retrieved passages into a short context.
        # (Original comment mislabeled this step as "Reader".)
        [summary] = summarize_text(models['summarizer'][0],
                                   models['summarizer'][1], [ctx])
        # Reader: extract the answer span from the summarized context.
        answers = ask_reader(models['reader'][0],
                             models['reader'][1], [question], [summary])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")
        st.session_state[question] = (ctx, summary, answer)

    # Render: answer first, then summarized context, then raw context.
    st.write(f"### Answer: {answer}")
    st.markdown('<h5>Summarized Context</h5>', unsafe_allow_html=True)
    st.markdown(
        f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
    st.markdown('<h5>Original Context</h5>', unsafe_allow_html=True)
    st.markdown(ctx)
# Run the pipeline whenever the input box holds a non-empty question.
if question:
    main(question)