datascientist22 committed on
Commit 7bdff6e
1 Parent(s): 7ff270d

Update app.py

Files changed (1)
  1. app.py +41 -34
app.py CHANGED
@@ -47,20 +47,25 @@ if 'chat_history' not in st.session_state:
     st.session_state.chat_history = []
 
 # Load the tokenizer and model
-tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
-model = AutoModelForCausalLM.from_pretrained("himmeow/vi-gemma-2b-RAG")
-
-# Use GPU if available
-device = "cuda" if torch.cuda.is_available() else "cpu"
-model = model.to(device)
+try:
+    tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
+    model = AutoModelForCausalLM.from_pretrained("himmeow/vi-gemma-2b-RAG")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = model.to(device)
+except Exception as e:
+    st.error(f"Error loading model or tokenizer: {e}")
+    st.stop()
 
 # Function to extract text from PDF files
 def extract_text_from_pdfs(files):
     text = ""
     for uploaded_file in files:
-        reader = PdfReader(uploaded_file)
-        for page in reader.pages:
-            text += page.extract_text() + "\n"
+        try:
+            reader = PdfReader(uploaded_file)
+            for page in reader.pages:
+                text += page.extract_text() + "\n"
+        except Exception as e:
+            st.error(f"Error reading PDF file: {e}")
     return text
 
 # Handle the query submission
@@ -73,33 +78,35 @@ if submit_button:
     try:
         # Extract text from uploaded PDFs
         pdf_text = extract_text_from_pdfs(uploaded_files)
-
-        # Prepare the input prompt
-        prompt = f"""
-        Based on the following context/document:
-        {pdf_text}
-        Please answer the question: {query}
-        """
-
-        # Encode the input text
-        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
-        # Generate the response
-        outputs = model.generate(
-            input_ids=input_ids,
-            max_new_tokens=500,
-            no_repeat_ngram_size=5,
-        )
-
-        # Decode the response and clean it
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        clean_response = response.strip()
-
-        # Update chat history
-        st.session_state.chat_history.append((query, clean_response))
+        if not pdf_text.strip():
+            st.warning("⚠️ No text found in the uploaded PDFs.")
+        else:
+            # Prepare the input prompt
+            prompt = f"""
+            Based on the following context/document:
+            {pdf_text}
+            Please answer the question: {query}
+            """
+
+            # Encode the input text
+            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=tokenizer.model_max_length)
+
+            # Generate the response
+            outputs = model.generate(
+                input_ids=inputs['input_ids'].to(device),
+                max_new_tokens=500,
+                no_repeat_ngram_size=5,
+            )
+
+            # Decode the response and clean it
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            clean_response = response.strip()
+
+            # Update chat history
+            st.session_state.chat_history.append((query, clean_response))
 
     except Exception as e:
-        st.error(f"An error occurred: {e}")
+        st.error(f"An error occurred during processing: {e}")
 
 # Display chat history
 if st.session_state.chat_history:
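
For reference, a minimal self-contained sketch of the generation path the updated app.py settles on: error-handled model loading, then truncation-aware tokenization followed by model.generate. The model id and generation parameters are taken from the diff; the answer() helper and the standalone-script structure are illustrative assumptions, not the app's actual layout.

# Minimal sketch of the post-commit generation path; assumes the same
# himmeow/vi-gemma-2b-RAG checkpoint used in the diff. The answer() helper
# is illustrative only; the real app.py drives this from Streamlit widgets.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("himmeow/vi-gemma-2b-RAG")
model = AutoModelForCausalLM.from_pretrained("himmeow/vi-gemma-2b-RAG")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def answer(context: str, question: str) -> str:
    prompt = (
        "Based on the following context/document:\n"
        f"{context}\n"
        f"Please answer the question: {question}"
    )
    # Truncate the prompt to the model's maximum context length,
    # mirroring the truncation added in this commit.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       max_length=tokenizer.model_max_length)
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(device),
        max_new_tokens=500,
        no_repeat_ngram_size=5,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

Passing only input_ids (without an explicit attention_mask) follows the diff; for a single unpadded sequence this is generally fine, though transformers may log a warning about the missing mask.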