datascientist22 commited on
Commit
d9e1771
1 Parent(s): 198dc13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -18
app.py CHANGED
@@ -3,18 +3,6 @@ from PyPDF2 import PdfReader
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
- # Initialize the tokenizer and model from the saved checkpoint
7
- tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
8
- model = AutoModelForCausalLM.from_pretrained(
9
- "himmeow/vi-gemma-2b-RAG",
10
- device_map="auto",
11
- torch_dtype=torch.bfloat16
12
- )
13
-
14
- # Use GPU if available
15
- if torch.cuda.is_available():
16
- model.to("cuda")
17
-
18
  # Set up the Streamlit app layout
19
  st.set_page_config(page_title="RAG PDF Chatbot", layout="wide")
20
 
@@ -58,6 +46,14 @@ submit_button = st.button("Submit")
58
  if 'chat_history' not in st.session_state:
59
  st.session_state.chat_history = []
60
 
 
 
 
 
 
 
 
 
61
  # Function to extract text from PDF files
62
  def extract_text_from_pdfs(files):
63
  text = ""
@@ -81,15 +77,11 @@ if submit_button and query:
81
  """
82
 
83
  # Encode the input text
84
- input_ids = tokenizer(prompt, return_tensors="pt")
85
-
86
- # Use GPU for input ids if available
87
- if torch.cuda.is_available():
88
- input_ids = input_ids.to("cuda")
89
 
90
  # Generate the response
91
  outputs = model.generate(
92
- **input_ids,
93
  max_new_tokens=500,
94
  no_repeat_ngram_size=5,
95
  )
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  # Set up the Streamlit app layout
7
  st.set_page_config(page_title="RAG PDF Chatbot", layout="wide")
8
 
 
46
  if 'chat_history' not in st.session_state:
47
  st.session_state.chat_history = []
48
 
49
+ # Load the tokenizer and model
50
+ tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
51
+ model = AutoModelForCausalLM.from_pretrained("himmeow/vi-gemma-2b-RAG")
52
+
53
+ # Use GPU if available
54
+ device = "cuda" if torch.cuda.is_available() else "cpu"
55
+ model = model.to(device)
56
+
57
  # Function to extract text from PDF files
58
  def extract_text_from_pdfs(files):
59
  text = ""
 
77
  """
78
 
79
  # Encode the input text
80
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
 
 
 
 
81
 
82
  # Generate the response
83
  outputs = model.generate(
84
+ input_ids=input_ids,
85
  max_new_tokens=500,
86
  no_repeat_ngram_size=5,
87
  )