datascientist22 commited on
Commit
e93e1aa
β€’
1 Parent(s): 2459f29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -79
app.py CHANGED
@@ -3,48 +3,7 @@ from PyPDF2 import PdfReader
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
- # App configuration
7
- st.set_page_config(page_title="PDF Chatbot", layout="wide")
8
- st.markdown(
9
- """
10
- <style>
11
- body {
12
- background: linear-gradient(90deg, rgba(255,224,230,1) 0%, rgba(224,255,255,1) 50%, rgba(224,240,255,1) 100%);
13
- color: #000;
14
- }
15
- </style>
16
- """,
17
- unsafe_allow_html=True
18
- )
19
-
20
- # Title and "Created by" section
21
- st.markdown("<h1 style='text-align: center; color: #FF69B4;'>πŸ“„ PDF RAG Chatbot</h1>", unsafe_allow_html=True)
22
- st.markdown(
23
- "<h4 style='text-align: center;'>Created by: <a href='https://www.linkedin.com/in/datascientisthameshraj/' style='color:#FF4500;'>Engr. Hamesh Raj</a></h4>",
24
- unsafe_allow_html=True
25
- )
26
-
27
- # Sidebar for PDF file upload
28
- uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
29
-
30
- # Query input
31
- query = st.text_input("Ask a question about the uploaded PDFs:")
32
-
33
- # Initialize session state to store conversation
34
- if "conversation" not in st.session_state:
35
- st.session_state.conversation = []
36
-
37
- # Function to extract text from PDFs
38
- def extract_text_from_pdfs(files):
39
- pdf_text = ""
40
- for file in files:
41
- reader = PdfReader(file)
42
- for page_num in range(len(reader.pages)):
43
- page = reader.pages[page_num]
44
- pdf_text += page.extract_text() + "\n"
45
- return pdf_text
46
-
47
- # Load model and tokenizer
48
  @st.cache_resource
49
  def load_model():
50
  tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
@@ -57,43 +16,80 @@ def load_model():
57
  model.to("cuda")
58
  return tokenizer, model
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  # Process and respond to user query
61
  if st.button("Submit"):
62
  if uploaded_files and query:
63
- pdf_text = extract_text_from_pdfs(uploaded_files)
64
- tokenizer, model = load_model()
65
-
66
- prompt = """
67
- ### Instruction and Input:
68
- Based on the following context/document:
69
- {}
70
- Please answer the question: {}
71
-
72
- ### Response:
73
- {}
74
- """
75
-
76
- input_text = prompt.format(pdf_text, query, " ")
77
- input_ids = tokenizer(input_text, return_tensors="pt")
78
-
79
- if torch.cuda.is_available():
80
- input_ids = input_ids.to("cuda")
81
-
82
- outputs = model.generate(
83
- **input_ids,
84
- max_new_tokens=500,
85
- no_repeat_ngram_size=5,
86
- )
87
-
88
- answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
89
-
90
- # Store the conversation
91
- st.session_state.conversation.insert(0, {"question": query, "answer": answer})
92
-
93
- # Display conversation
94
- if st.session_state.conversation:
95
- st.markdown("## Previous Conversations")
96
- for qa in st.session_state.conversation:
97
- st.markdown(f"**Q: {qa['question']}**")
98
- st.markdown(f"**A: {qa['answer']}**")
99
- st.markdown("---")
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
+ # Cache the model and tokenizer loading
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  @st.cache_resource
8
  def load_model():
9
  tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
 
16
  model.to("cuda")
17
  return tokenizer, model
18
 
19
+ # Cache the text extraction from PDFs
20
+ @st.cache_data
21
+ def extract_text_from_pdfs(files):
22
+ pdf_text = ""
23
+ for file in files:
24
+ reader = PdfReader(file)
25
+ for page_num in range(len(reader.pages)):
26
+ page = reader.pages[page_num]
27
+ pdf_text += page.extract_text() + "\n"
28
+ return pdf_text
29
+
30
+ # Load the model and tokenizer
31
+ tokenizer, model = load_model()
32
+
33
+ # Sidebar for PDF file upload
34
+ st.sidebar.title("πŸ“‚ Upload PDFs")
35
+ uploaded_files = st.sidebar.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
36
+
37
+ # Initialize session state
38
+ if "history" not in st.session_state:
39
+ st.session_state.history = []
40
+
41
+ # Extract text from PDFs and maintain session state
42
+ if uploaded_files:
43
+ if "pdf_text" not in st.session_state:
44
+ st.session_state.pdf_text = extract_text_from_pdfs(uploaded_files)
45
+
46
+ # Main interface
47
+ st.title("πŸ’¬ RAG PDF Chatbot")
48
+ st.markdown("Ask questions based on the uploaded PDF documents.")
49
+
50
+ # Input for user query
51
+ query = st.text_input("Enter your question:")
52
+
53
  # Process and respond to user query
54
  if st.button("Submit"):
55
  if uploaded_files and query:
56
+ with st.spinner("Generating response..."):
57
+ # Prepare the input data
58
+ prompt = """
59
+ ### Instruction and Input:
60
+ Based on the following context/document:
61
+ {}
62
+ Please answer the question: {}
63
+
64
+ ### Response:
65
+ """.format(st.session_state.pdf_text, query)
66
+
67
+ # Encode the input text
68
+ input_ids = tokenizer(prompt, return_tensors="pt")
69
+
70
+ # Use GPU for input ids if available
71
+ if torch.cuda.is_available():
72
+ input_ids = input_ids.to("cuda")
73
+
74
+ # Generate text using the model
75
+ outputs = model.generate(
76
+ **input_ids,
77
+ max_new_tokens=500, # Limit tokens to speed up generation
78
+ no_repeat_ngram_size=3, # Avoid repetition
79
+ do_sample=True, # Sampling for variability
80
+ temperature=0.7 # Control randomness
81
+ )
82
+
83
+ # Decode and display the results
84
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+ st.session_state.history.append({"question": query, "answer": response})
86
+
87
+ # Display chat history
88
+ if st.session_state.history:
89
+ for i, qa in enumerate(reversed(st.session_state.history), 1):
90
+ st.markdown(f"**Q{i}:** {qa['question']}")
91
+ st.markdown(f"**A{i}:** {qa['answer']}")
92
+
93
+ # Footer with author information
94
+ st.sidebar.markdown("### Created by: [Engr. Hamesh Raj](https://www.linkedin.com/in/datascientisthameshraj/)")
95
+ st.sidebar.markdown("## πŸ—‚οΈ RAG PDF Chatbot")