seonglae committed
Commit 76ce883
1 Parent(s): 068bab1

feat: cpu support using cuda condition branch

Files changed (2)
  1. app.py +2 -0
  2. model.py +34 -14
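
Both diffs below apply the same pattern described by the commit message: compute a single module-level flag `cuda = torch.cuda.is_available()` and branch on it wherever tensors or models were previously moved to a GPU unconditionally, so the app also runs on CPU-only hosts. A minimal standalone sketch of that pattern follows; the `generate_ids` helper is illustrative only and is not part of the commit.

import torch

# Module-level flag, mirroring the pattern used in model.py below.
cuda = torch.cuda.is_available()

def generate_ids(model, inputs):
  # Illustrative helper (not in the commit): move inputs to GPU 0 and enable the
  # flash SDP attention kernel only when CUDA is present; otherwise stay on CPU.
  if cuda:
    inputs = inputs.to(0)
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
      return model.generate(inputs["input_ids"])
  return model.generate(inputs["input_ids"])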
app.py CHANGED

@@ -2,6 +2,7 @@ import os
 
 import streamlit as st
 from pymilvus import MilvusClient
+import torch
 
 from model import encode_dpr_question, get_dpr_encoder
 from model import summarize_text, get_summarizer
@@ -56,6 +57,7 @@ st.markdown(styl, unsafe_allow_html=True)
 question = st.text_area("Text to summarize", INITIAL, height=400)
 
 
+@torch.inference_mode()
 def main(question: str):
   if question in st.session_state:
     print("Cache hit!")
model.py CHANGED

@@ -7,15 +7,21 @@ from transformers import QuestionAnsweringPipeline
 from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast
 import torch
 
+cuda = torch.cuda.is_available()
 max_answer_len = 8
 logging.set_verbosity_error()
 
 
+@torch.inference_mode()
 def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                    input_texts: List[str]):
   inputs = tokenizer(input_texts, padding=True,
-                     return_tensors='pt', truncation=True).to(1)
-  with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+                     return_tensors='pt', truncation=True)
+  if cuda:
+    inputs = inputs.to(0)
+    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+      summary_ids = model.generate(inputs["input_ids"])
+  else:
     summary_ids = model.generate(inputs["input_ids"])
   summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                      clean_up_tokenization_spaces=False, batch_size=len(input_texts))
@@ -24,14 +30,13 @@ def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditiona
 
 def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
   tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
-  model = PegasusXForConditionalGeneration.from_pretrained(model_id).to(1)
+  model = PegasusXForConditionalGeneration.from_pretrained(model_id)
+  if cuda:
+    model = model.to(0)
   model = torch.compile(model)
   return tokenizer, model
 
 
-# OpenAI reader
-
-
 class AnswerInfo(TypedDict):
   score: float
   start: int
@@ -42,10 +47,16 @@ class AnswerInfo(TypedDict):
 @torch.inference_mode()
 def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
                questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
-  with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+  if cuda:
+    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+      pipeline = QuestionAnsweringPipeline(
+          model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
+      answer_infos: List[AnswerInfo] = pipeline(
+          question=questions, context=ctxs)
+  else:
     pipeline = QuestionAnsweringPipeline(
-        model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
-    answer_infos: List[AnswerInfo] = pipeline(
+        model=model, tokenizer=tokenizer, device='cpu', max_answer_len=max_answer_len)
+    answer_infos = pipeline(
         question=questions, context=ctxs)
   for answer_info in answer_infos:
     answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
@@ -54,10 +65,13 @@ def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
 
 def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
   tokenizer = DPRReaderTokenizer.from_pretrained(model_id)
-  model = DPRReader.from_pretrained(model_id).to(0)
+  model = DPRReader.from_pretrained(model_id)
+  if cuda:
+    model = model.to(0)
   return tokenizer, model
 
 
+@torch.inference_mode()
 def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
   """Encode a question using DPR question encoder.
   https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder
@@ -67,9 +81,13 @@ def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuesti
     model_id (str, optional): Default for NQ or "facebook/dpr-question_encoder-multiset-base
   """
   batch_dict = tokenizer(questions, return_tensors="pt",
-                         padding=True, truncation=True,).to(0)
-  with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
-    embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
+                         padding=True, truncation=True)
+  if cuda:
+    batch_dict = batch_dict.to(0)
+    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+      embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
+  else:
+    embeddings = model(**batch_dict).pooler_output
   return embeddings
 
 
@@ -82,5 +100,7 @@ def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") ->
     model_id (str, optional): Default for NQ or "facebook/dpr-question_encoder-multiset-base
   """
   tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
-  model = DPRQuestionEncoder.from_pretrained(model_id).to(0)
+  model = DPRQuestionEncoder.from_pretrained(model_id)
+  if cuda:
+    model = model.to(0)
   return tokenizer, model
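
With these branches in place, the same call sites work on a CPU-only machine and on a GPU host without code changes. A hedged usage sketch, assuming `summarize_text` returns the decoded summaries (its return statement sits outside the hunks shown above):

from model import get_summarizer, summarize_text

tokenizer, model = get_summarizer()  # moves the model to GPU 0 only if CUDA is available
summaries = summarize_text(tokenizer, model,
                           ["A long passage to compress into a short abstract ..."])
print(summaries[0])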