twimbit-ai committed on
Commit
c6f8a21
•
1 Parent(s): 3b44354

Create test_web_rag.py

Files changed (1)
  1. test_web_rag.py +263 -0
test_web_rag.py ADDED
@@ -0,0 +1,263 @@
+ import urllib.request
+ import urllib.error
+ from urllib.parse import quote
+ from seleniumbase import SB
+ import markdownify
+ from bs4 import BeautifulSoup
+ from requests_html import HTMLSession
+ import html2text
+ import re
+ from openai import OpenAI
+ import tiktoken
+ from zenrows import ZenRowsClient
+ import requests
+ import os
+ from dotenv import load_dotenv
+
+ load_dotenv()
+ ZENROWS_KEY = os.getenv('ZENROWS_KEY')
+ client = OpenAI()
+
+
+ def get_fast_url_source(url):
+     session = HTMLSession()
+     r = session.get(url)
+     return r.text
+
+
+ def convert_html_to_text(html):
+     h = html2text.HTML2Text()
+     h.body_width = 0  # Disable line wrapping
+     text = h.handle(html)
+     text = re.sub(r'\n\s*', '', text)
+     text = re.sub(r'\* \\', '', text)
+     text = " ".join(text.split())
+     return text
+
+
+ def get_google_search_url(query):
+     url = 'https://www.google.com/search?q=' + quote(query)
+     # Build the request
+     request = urllib.request.Request(url)
+
+     # Set a normal User-Agent header, otherwise Google will block the request.
+     request.add_header('User-Agent',
+                        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36')
+     raw_response = urllib.request.urlopen(request).read()
+
+     # Read the response as a utf-8 string
+     html = raw_response.decode("utf-8")
+
+     # Parse the HTML contents
+     soup = BeautifulSoup(html, 'html.parser')
+
+     # Find all the search result divs
+     divs = soup.select("#search div.g")
+     urls = []
+     for div in divs:
+         # Keep the first link of each result block
+         anchors = div.select('a')
+         if anchors:
+             urls.append(anchors[0]['href'])
+     return urls
+
+
+ def format_text(text):
+     # Keep only the main text-bearing tags from the page
+     soup = BeautifulSoup(text, 'html.parser')
+     results = soup.find_all(['p', 'h1', 'h2', 'span'])
+     text = ''
+     for result in results:
+         text = text + str(result) + '  '
+     return text
+
+
+ def get_page_source_selenium_base(url):
+     # Render the page in a headless browser for JavaScript-heavy sites
+     with SB(uc_cdp=True, guest_mode=True, headless=True) as sb:
+         sb.open(url)
+         sb.sleep(5)
+         page_source = sb.get_page_source()
+     return page_source
+
+
+ def num_tokens_from_string(string: str, encoding_name: str) -> int:
+     encoding = tiktoken.get_encoding(encoding_name)
+     # encoding = tiktoken.encoding_for_model(encoding_name)
+     num_tokens = len(encoding.encode(string))
+     return num_tokens
+
+
+ def encoding_getter(encoding_type: str):
+     """
+     Returns the appropriate encoding based on the given encoding type (either an encoding string or a model name).
+     """
+     if "k_base" in encoding_type:
+         return tiktoken.get_encoding(encoding_type)
+     else:
+         return tiktoken.encoding_for_model(encoding_type)
+
+
+ def tokenizer(string: str, encoding_type: str) -> list:
+     """
+     Returns the tokens in a text string using the specified encoding.
+     """
+     encoding = encoding_getter(encoding_type)
+     tokens = encoding.encode(string)
+     return tokens
+
+
+ def token_counter(string: str, encoding_type: str) -> int:
+     """
+     Returns the number of tokens in a text string using the specified encoding.
+     """
+     num_tokens = len(tokenizer(string, encoding_type))
+     return num_tokens
+
+
+ def format_output(text):
+     page_source = format_text(text)
+     page_source = markdownify.markdownify(page_source)
+     # page_source = convert_html_to_text(page_source)
+     page_source = " ".join(page_source.split())
+     return page_source
+
+
+ def clean_text(text):
+     # Remove URLs
+     text = re.sub(r'http[s]?://\S+', '', text)
+
+     # Remove special characters and punctuation (keep only letters, numbers, and basic punctuation)
+     text = re.sub(r'[^a-zA-Z0-9\s,.!?-]', '', text)
+
+     # Normalize whitespace
+     text = re.sub(r'\s+', ' ', text).strip()
+
+     return text
+
+
+ def call_open_ai(system_prompt, max_tokens=800, stream=False):
+     messages = [
+         {
+             "role": "user",
+             "content": system_prompt
+         }
+     ]
+
+     response = client.chat.completions.create(
+         model="gpt-3.5-turbo",
+         messages=messages,
+         temperature=0,
+         max_tokens=max_tokens,
+         top_p=0,
+         frequency_penalty=0,
+         presence_penalty=0,
+         stream=stream
+     )
+     # Note: the return below assumes stream=False (a full ChatCompletion object)
+     return response.choices[0].message.content
+
+
+ def url_summary(text, question):
+     system_prompt = """
+     Summarize the given text; please include all the important topics and numerical data.
+
+     While summarizing, please keep this question in mind.
+     question:- {question}
+
+     text:
+     {text}
+     """.format(question=question, text=text)
+     return call_open_ai(system_prompt=system_prompt, max_tokens=800)
+
+
+ def get_google_search_query(question):
+     system_prompt = """
+     Convert this question into a Google search query and return only the query.
+     question:- {question}
+     """.format(question=question)
+
+     return call_open_ai(system_prompt=system_prompt, max_tokens=50)
+
+
+ def is_urlfile(url):
+     # Check if an online file exists
+     try:
+         r = urllib.request.urlopen(url)  # response
+         return r.getcode() == 200
+     except urllib.error.HTTPError:
+         return False
+
+
+ def check_url_pdf_file(url):
+     r = requests.get(url)
+     content_type = r.headers.get('content-type', '')
+
+     if 'application/pdf' in content_type:
+         return True
+     else:
+         return False
+
+
+ def zenrows_scrapper(url):
+     # Scrape the url through the ZenRows proxy with JavaScript rendering enabled
+     zen_client = ZenRowsClient(ZENROWS_KEY)
+     params = {"js_render": "true"}
+     response = zen_client.get(url, params=params)
+
+     return response.text
+
+
+ def get_new_question_from_history(pre_question, new_question, answer):
+     system_prompt = """
+     Generate a new Google search query using the previous question and answer. Return only the query.
+
+
+     previous question:- {pre_question}
+     answer:- {answer}
+
+     new question:- {new_question}
+     """.format(pre_question=pre_question, answer=answer, new_question=new_question)
+
+     return call_open_ai(system_prompt=system_prompt, max_tokens=50)
+
+
+ def get_docs_from_web(question, history, n_web_search, strategy):
+     if history:
+         question = get_new_question_from_history(history[0][0], question, history[0][1])
+     urls = get_google_search_url(get_google_search_query(question))[:n_web_search]
+     urls = list(set(urls))
+     docs = ''
+     yield f"Scraping started for {len(urls)} urls:-\n\n"
+     for key, url in enumerate(urls):
+         if '.pdf' in url:
+             yield f"Scraping skipped, pdf detected. {key + 1}/{len(urls)} - {url} ❌\n"
+             continue
+
+         if strategy == 'Deep':
+             # page_source = get_page_source_selenium_base(url)
+             page_source = zenrows_scrapper(url)
+         else:
+             page_source = get_fast_url_source(url)
+         formatted_page_source = format_output(page_source)
+         formatted_page_source = clean_text(formatted_page_source)
+
+         tokens = token_counter(formatted_page_source, 'gpt-3.5-turbo')
+
+         if tokens >= 15585:
+             yield f"Scraping skipped as token limit exceeded. {key + 1}/{len(urls)} - {url} ❌\n"
+             continue
+
+         summary = url_summary(formatted_page_source, question)
+         docs += summary
+         docs += '\n Source:- ' + url + '\n\n'
+         yield f"Scraping Done {key + 1}/{len(urls)} - {url} ✅\n"
+     yield {"data": docs}
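
For reference, a minimal usage sketch (not part of the commit) showing how the get_docs_from_web generator might be driven. It assumes the file is importable as test_web_rag and that OPENAI_API_KEY and ZENROWS_KEY are set in the environment; the question string and the "Normal" strategy label are purely illustrative, since any value other than 'Deep' selects the fast requests-html path.

from test_web_rag import get_docs_from_web

# Hypothetical driver: question and strategy values are illustrative only.
for update in get_docs_from_web(
        question="What is retrieval augmented generation?",
        history=[],             # no prior turns, so no query rewriting
        n_web_search=3,         # summarize at most 3 search results
        strategy="Normal"):     # anything other than 'Deep' uses the fast path
    if isinstance(update, dict):
        print(update["data"])   # final combined summaries with source URLs
    else:
        print(update, end="")   # streaming progress messages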