"""ReSRer demo: a Streamlit app that answers short factual questions by
retrieving Wikipedia passages from Milvus, summarizing them, and running
an extractive reader over the summary."""
import os

import streamlit as st
import torch
from pymilvus import MilvusClient

from model import ask_reader, get_reader
from model import encode_dpr_question, get_dpr_encoder
from model import get_summarizer, summarize_text

TITLE = 'ReSRer: Retriever-Summarizer-Reader'
INITIAL = "What is the population of NYC"

st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
Ask a short-answer question that can be found in Wikipedia data.
''', unsafe_allow_html=True)
st.markdown('This demo retrieves the original context from 21,000,000 '
            'wikipedia passages in real-time')


@st.cache_resource
def load_models():
    """Load the (model, tokenizer) pairs once per Streamlit server process.

    Returns:
        dict with keys 'encoder', 'summarizer', 'reader', each holding a
        (model, tokenizer) pair as returned by the `model` module helpers.
    """
    models = {}
    models['encoder'] = get_dpr_encoder()
    models['summarizer'] = get_summarizer()
    models['reader'] = get_reader()
    return models


@st.cache_resource
def load_client():
    """Connect to the Milvus vector database holding the passage index.

    Requires MILVUS_PW and MILVUS_HOST in the environment; raises KeyError
    if either is missing.
    """
    client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                          uri=f"http://{os.environ['MILVUS_HOST']}:19530",
                          db_name='psgs_w100')
    return client


client = load_client()
models = load_models()

question = st.text_input("Question", INITIAL)


@torch.inference_mode()
def main(question: str) -> None:
    """Answer *question* end-to-end and render the result in the page.

    Pipeline: DPR-encode the question -> nearest-neighbor search in Milvus
    -> summarize the retrieved passages -> extract an answer span from the
    summary. Results are memoized per question in st.session_state so a
    rerun of the same question skips recomputation.
    """
    if question in st.session_state:
        # Same question asked before in this session: reuse stored results.
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")

        # Retriever step 1: embed the question as a dense DPR vector.
        question_vectors = encode_dpr_question(
            models['encoder'][0], models['encoder'][1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]

        # Retriever step 2: top-10 nearest passages from the index.
        results = client.search(collection_name='dpr_nq',
                                data=[query_vector], limit=10,
                                output_fields=['title', 'text'])
        texts = [hit['entity']['text'] for hit in results[0]]
        ctx = '\n'.join(texts)

        # Summarizer: condense the concatenated passages, then Reader:
        # extract the answer span from the summary.
        [summary] = summarize_text(models['summarizer'][0],
                                   models['summarizer'][1], [ctx])
        answers = ask_reader(models['reader'][0], models['reader'][1],
                             [question], [summary])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")
        st.session_state[question] = (ctx, summary, answer)

    # Render: answer first, then the summarized and original contexts.
    # NOTE(review): the original section headers used HTML markup that was
    # lost/garbled in this copy; reconstructed as plain Streamlit headers.
    st.write(f"### Answer: {answer}")
    st.subheader('Summarized Context')
    st.markdown(summary)
    st.subheader('Original Context')
    st.markdown(ctx)


if question:
    main(question)