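"""Streamlit demo app for ReSRer (Retriever-Summarizer-Reader).

Pipeline: encode the question with DPR, retrieve passages from Milvus,
summarize the retrieved context, then read an answer out of the summary.
"""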
import os

import streamlit as st
from pymilvus import MilvusClient
import torch

from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader


TITLE = 'ReSRer: Retriever-Summarizer-Reader'
INITIAL = "What is the population of NYC?"

st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
<h5>Ask a short-answer question that can be answered from Wikipedia data.</h5>
''', unsafe_allow_html=True)
st.markdown('This demo retrieves the original context from 21,000,000 Wikipedia passages in real time.')


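# Cache the heavyweight models across Streamlit reruns; they are loaded once
# per process and shared between sessions.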
@st.cache_resource
def load_models():
  models = {}
  models['encoder'] = get_dpr_encoder()
  models['summarizer'] = get_summarizer()
  models['reader'] = get_reader()
  return models


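# Connect to the Milvus collection of Wikipedia passages; credentials and host
# come from the MILVUS_PW and MILVUS_HOST environment variables.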
@st.cache_resource
def load_client():
  client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                        uri=f"http://{os.environ['MILVUS_HOST']}:19530", db_name='psgs_w100')
  return client


client = load_client()
models = load_models()

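# Pin Streamlit's status widget to the center of the page and hide its stop
# button while the pipeline runs.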
style = """
<style>
    .StatusWidget-enter-done{
      position: fixed;
      left: 50%;
      top: 50%;
      transform: translate(-50%, -50%);
    }
    .StatusWidget-enter-done button{
      display: none;
    }
</style>
"""
st.markdown(style, unsafe_allow_html=True)


question = st.text_input("Question", INITIAL)


@torch.inference_mode()
def main(question: str):
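  # Answers are memoized per session: repeated questions skip the full pipeline.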
  if question in st.session_state:
    print("Cache hit!")
    ctx, summary, answer = st.session_state[question]
  else:
    print(f"Input: {question}")
    # Embedding
    question_vectors = encode_dpr_question(
        models['encoder'][0], models['encoder'][1], [question])
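    # Milvus takes plain Python lists, so detach the embedding tensor,
    # move it to the CPU, and convert it before searching.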
    query_vector = question_vectors.detach().cpu().numpy().tolist()[0]

    # Retriever
    results = client.search(collection_name='dpr_nq', data=[query_vector],
                            limit=10, output_fields=['title', 'text'])
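    # Join the top-10 retrieved passages into a single context string.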
    texts = [result['entity']['text'] for result in results[0]]
    ctx = '\n'.join(texts)

    # Summarizer: condense the retrieved passages before reading.
    [summary] = summarize_text(models['summarizer'][0],
                               models['summarizer'][1], [ctx])
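    # Reader: answer the question against the summarized context.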
    answers = ask_reader(models['reader'][0],
                         models['reader'][1], [question], [summary])
    answer = answers[0]['answer']
    print(f"\nAnswer: {answer}")

    st.session_state[question] = (ctx, summary, answer)

  # Display the answer, the summarized context, and the original context.
  st.write(f"### Answer: {answer}")
  st.markdown('<h5>Summarized Context</h5>', unsafe_allow_html=True)
  st.markdown(
      f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
  st.markdown('<h5>Original Context</h5>', unsafe_allow_html=True)
  st.markdown(ctx)


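# Streamlit reruns this script on every interaction; run the pipeline whenever
# the text input is non-empty.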
if question:
  main(question)