yonikremer
committed on
Commit
•
d102e03
1
Parent(s):
c9089bd
Added search engine prompt engineering
Browse files- hanlde_form_submit.py +6 -2
- prompt_engeneering.py +53 -0
hanlde_form_submit.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from grouped_sampling import GroupedSamplingPipeLine
|
3 |
|
|
|
4 |
from supported_models import get_supported_model_names
|
5 |
|
6 |
|
@@ -15,7 +16,7 @@ def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
|
|
15 |
:return: A pipeline with the given model name and group size.
|
16 |
"""
|
17 |
print(f"Starts downloading model: {model_name} from the internet.")
|
18 |
-
pipeline =
|
19 |
model_name=model_name,
|
20 |
group_size=group_size,
|
21 |
end_of_sentence_stop=True,
|
@@ -38,9 +39,12 @@ def generate_text(
|
|
38 |
:param output_length: The size of the groups to use.
|
39 |
:return: The generated text.
|
40 |
"""
|
|
|
41 |
return pipeline(
|
42 |
-
prompt_s=
|
43 |
max_new_tokens=output_length,
|
|
|
|
|
44 |
)["generated_text"]
|
45 |
|
46 |
|
|
|
1 |
import streamlit as st
|
2 |
from grouped_sampling import GroupedSamplingPipeLine
|
3 |
|
4 |
+
from prompt_engeneering import rewrite_prompt
|
5 |
from supported_models import get_supported_model_names
|
6 |
|
7 |
|
|
|
16 |
:return: A pipeline with the given model name and group size.
|
17 |
"""
|
18 |
print(f"Starts downloading model: {model_name} from the internet.")
|
19 |
+
pipeline = GroupedSamplingPipeLine(
|
20 |
model_name=model_name,
|
21 |
group_size=group_size,
|
22 |
end_of_sentence_stop=True,
|
|
|
39 |
:param output_length: The size of the groups to use.
|
40 |
:return: The generated text.
|
41 |
"""
|
42 |
+
better_prompt = rewrite_prompt(prompt, 5)
|
43 |
return pipeline(
|
44 |
+
prompt_s=better_prompt,
|
45 |
max_new_tokens=output_length,
|
46 |
+
return_text=True,
|
47 |
+
return_full_text=False,
|
48 |
)["generated_text"]
|
49 |
|
50 |
|
prompt_engeneering.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from datetime import datetime
|
3 |
+
from typing import Generator
|
4 |
+
|
5 |
+
import requests
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
class SearchResult:
    """A single web search hit.

    The attributes are declared as dataclass fields so that the generated
    ``__init__``, ``__repr__`` and ``__eq__`` actually take them into
    account. With a hand-written ``__init__`` and no annotated fields the
    dataclass has zero fields, which makes every instance compare equal and
    print as ``SearchResult()``.
    """

    # Title of the result.
    title: str
    # Text body (snippet) of the result.
    body: str
    # Address the result points to.
    url: str
|
14 |
+
|
15 |
+
|
16 |
+
def get_web_search_results(
        prompt: str,
        num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yields web search results for the given prompt.

    Sends a GET request to the search endpoint hosted at
    ddg-webapp-aagd.vercel.app and converts every item of the JSON response
    into a SearchResult. Because this is a generator, the request is only
    sent once the first result is requested.

    :param prompt: The search query.
    :param num_results: Passed to the endpoint as `max_results`.
    :return: A generator of SearchResult objects built from the
        "title", "body" and "href" keys of each returned item.
    :raises ValueError: If the endpoint answers with a status code other than 200.
    """
    # The query is passed through `params` so that requests URL-encodes it;
    # characters such as '&', '#' or '+' in the prompt would otherwise
    # corrupt the query string. This also drops the stray '$' that used to
    # be sent in front of every query.
    response = requests.get(
        "https://ddg-webapp-aagd.vercel.app/search",
        params={"max_results": num_results, "q": prompt},
        # Without a timeout requests may wait forever on an unresponsive server.
        timeout=30,
    )
    if response.status_code != 200:
        raise ValueError(f"Failed to get web search results for prompt: {prompt}")
    for result in response.json():
        yield SearchResult(
            title=result["title"],
            body=result["body"],
            url=result["href"],
        )
|
33 |
+
|
34 |
+
|
35 |
+
def format_search_result(search_result: Generator["SearchResult", None, None]) -> str:
    """Formats search results to be added to the prompt.

    Every result becomes a block of the form::

        [<index>] <body>
        URL: <url>

    followed by an empty line. Indices start at 0 and follow the order in
    which the results are produced. The title of a result is not used.

    :param search_result: The search results to format.
    :return: The concatenated blocks, or an empty string if there are no results.
    """
    # Build the text in a single join instead of repeated `+=` on a string.
    return "".join(
        f"[{i}] {result.body}\nURL: {result.url}\n\n"
        for i, result in enumerate(search_result)
    )
|
41 |
+
|
42 |
+
|
43 |
+
def rewrite_prompt(
        prompt: str,
        num_results: int,
) -> str:
    """Builds an extended prompt that contains web search results.

    The returned text has four newline-separated sections, in this order:
    the formatted web search results, the current local date as DD/MM/YYYY,
    fixed instructions for answering with citations, and the original query.

    :param prompt: The original query; it is also used as the search query.
    :param num_results: The number of search results to request.
    :return: The extended prompt.
    """
    instructions = (
        "Instructions: Using the provided web search results, "
        "write a comprehensive reply to the given query. "
        "Make sure to cite results using [[number](URL)] notation after the reference. "
        "If the provided search results refer to multiple subjects with the same name, "
        "write separate answers for each subject."
    )
    # The search runs before the date is read, as the sections are built in order.
    search_section = "Web search results: " + format_search_result(
        get_web_search_results(prompt, num_results)
    )
    date_section = f"Current date: {datetime.now():%d/%m/%Y}"
    query_section = "Query: " + prompt
    return f"{search_section}\n{date_section}\n{instructions}\n{query_section}"
|