Spaces:
Running
Running
feat: huggingface space pipeline with resrer model
Browse files- app.py +97 -2
- model.py +86 -0
- requirements.txt +3 -0
app.py
CHANGED
@@ -1,2 +1,97 @@
|
|
1 |
-
|
2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from pymilvus import MilvusClient
|
5 |
+
|
6 |
+
from model import encode_dpr_question, get_dpr_encoder
|
7 |
+
from model import summarize_text, get_summarizer
|
8 |
+
from model import ask_reader, get_reader
|
9 |
+
|
10 |
+
|
11 |
+
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
|
12 |
+
INITIAL = "What is the population of NYC"
|
13 |
+
|
14 |
+
st.set_page_config(page_title=TITLE)
|
15 |
+
st.header(TITLE)
|
16 |
+
st.markdown('''
|
17 |
+
### Ask short-answer question that can be find in Wikipedia data.
|
18 |
+
''', unsafe_allow_html=True)
|
19 |
+
|
20 |
+
|
21 |
+
@st.cache_resource
|
22 |
+
def load_models():
|
23 |
+
models = {}
|
24 |
+
models['encoder'] = get_dpr_encoder()
|
25 |
+
models['summarizer'] = get_summarizer()
|
26 |
+
models['reader'] = get_reader()
|
27 |
+
return models
|
28 |
+
|
29 |
+
|
30 |
+
@st.cache_resource
|
31 |
+
def load_client():
|
32 |
+
client = MilvusClient(user='resrer', password=os.env['MILVUS_PW'],
|
33 |
+
uri=f"http://{os.env['MILVUS_HOST']}:19530", db_name='psgs_w100')
|
34 |
+
return client
|
35 |
+
|
36 |
+
|
37 |
+
client = load_client()
|
38 |
+
models = load_models()
|
39 |
+
|
40 |
+
styl = """
|
41 |
+
<style>
|
42 |
+
.StatusWidget-enter-done{
|
43 |
+
position: fixed;
|
44 |
+
left: 50%;
|
45 |
+
top: 50%;
|
46 |
+
transform: translate(-50%, -50%);
|
47 |
+
}
|
48 |
+
.StatusWidget-enter-done button{
|
49 |
+
display: none;
|
50 |
+
}
|
51 |
+
</style>
|
52 |
+
"""
|
53 |
+
st.markdown(styl, unsafe_allow_html=True)
|
54 |
+
|
55 |
+
|
56 |
+
question = st.text_area("Text to summarize", INITIAL, height=400)
|
57 |
+
|
58 |
+
|
59 |
+
def main(question: str):
|
60 |
+
if question in st.session_state:
|
61 |
+
print("Cache hit!")
|
62 |
+
ctx, summary, answer = st.session_state[question]
|
63 |
+
else:
|
64 |
+
print(f"Input: {question}")
|
65 |
+
# Embedding
|
66 |
+
question_vectors = encode_dpr_question(
|
67 |
+
models['encoder'][0], models['encoder'][1], [question])
|
68 |
+
query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
|
69 |
+
|
70 |
+
# Retriever
|
71 |
+
results = client.search(collection_name='dpr_nq', data=[
|
72 |
+
query_vector], limit=10, output_fields=['title', 'text'])
|
73 |
+
texts = [result['entity']['text'] for result in results[0]]
|
74 |
+
ctx = '\n'.join(texts)
|
75 |
+
|
76 |
+
# Reader
|
77 |
+
summary = summarize_text(models['summarizer'][0],
|
78 |
+
models['summarizer'][1], [summary])
|
79 |
+
answers = ask_reader(models['reader'][0],
|
80 |
+
models['reader'][1], [question], [ctx])
|
81 |
+
answer = answers[0]['answer']
|
82 |
+
print(f"\nAnswer: {answer}")
|
83 |
+
|
84 |
+
st.session_state[question] = (ctx, summary, answer)
|
85 |
+
|
86 |
+
# Summary
|
87 |
+
st.markdown(answer)
|
88 |
+
st.write("## Summary")
|
89 |
+
st.markdown(
|
90 |
+
f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
|
91 |
+
st.markdown(ctx)
|
92 |
+
|
93 |
+
st.write(f"{question}", unsafe_allow_html=True)
|
94 |
+
|
95 |
+
|
96 |
+
if question:
|
97 |
+
main(question)
|
model.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, TypedDict
|
2 |
+
from re import sub
|
3 |
+
|
4 |
+
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
|
5 |
+
from transformers import AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader
|
6 |
+
from transformers import QuestionAnsweringPipeline
|
7 |
+
from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast
|
8 |
+
import torch
|
9 |
+
|
10 |
+
max_answer_len = 8
|
11 |
+
logging.set_verbosity_error()
|
12 |
+
|
13 |
+
|
14 |
+
def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
|
15 |
+
input_texts: List[str]):
|
16 |
+
inputs = tokenizer(input_texts, padding=True,
|
17 |
+
return_tensors='pt', truncation=True).to(1)
|
18 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
19 |
+
summary_ids = model.generate(inputs["input_ids"])
|
20 |
+
summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
|
21 |
+
clean_up_tokenization_spaces=False, batch_size=len(input_texts))
|
22 |
+
return summaries
|
23 |
+
|
24 |
+
|
25 |
+
def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
|
26 |
+
tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
|
27 |
+
model = PegasusXForConditionalGeneration.from_pretrained(model_id).to(1)
|
28 |
+
model = torch.compile(model)
|
29 |
+
return tokenizer, model
|
30 |
+
|
31 |
+
|
32 |
+
# OpenAI reader
|
33 |
+
|
34 |
+
|
35 |
+
class AnswerInfo(TypedDict):
|
36 |
+
score: float
|
37 |
+
start: int
|
38 |
+
end: int
|
39 |
+
answer: str
|
40 |
+
|
41 |
+
|
42 |
+
@torch.inference_mode()
|
43 |
+
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
|
44 |
+
questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
|
45 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
46 |
+
pipeline = QuestionAnsweringPipeline(
|
47 |
+
model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
|
48 |
+
answer_infos: List[AnswerInfo] = pipeline(
|
49 |
+
question=questions, context=ctxs)
|
50 |
+
for answer_info in answer_infos:
|
51 |
+
answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
|
52 |
+
return answer_infos
|
53 |
+
|
54 |
+
|
55 |
+
def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
|
56 |
+
tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
|
57 |
+
model = DPRReader.from_pretrained(model_id).to(0)
|
58 |
+
return tokenizer, model
|
59 |
+
|
60 |
+
|
61 |
+
def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
|
62 |
+
"""Encode a question using DPR question encoder.
|
63 |
+
https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder
|
64 |
+
|
65 |
+
Args:
|
66 |
+
question (str): question string to encode
|
67 |
+
model_id (str, optional): Default for NQ or "facebook/dpr-question_encoder-multiset-base
|
68 |
+
"""
|
69 |
+
batch_dict = tokenizer(questions, return_tensors="pt",
|
70 |
+
padding=True, truncation=True,).to(0)
|
71 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
72 |
+
embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
|
73 |
+
return embeddings
|
74 |
+
|
75 |
+
|
76 |
+
def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoder, DPRQuestionEncoderTokenizer]:
|
77 |
+
"""Encode a question using DPR question encoder.
|
78 |
+
https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder
|
79 |
+
|
80 |
+
Args:
|
81 |
+
question (str): question string to encode
|
82 |
+
model_id (str, optional): Default for NQ or "facebook/dpr-question_encoder-multiset-base
|
83 |
+
"""
|
84 |
+
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
|
85 |
+
model = DPRQuestionEncoder.from_pretrained(model_id).to(0)
|
86 |
+
return tokenizer, model
|
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
torch
|
3 |
+
pymilvus
|