yonikremer committed on
Commit
d102e03
1 Parent(s): c9089bd

Added search engine prompt engineering

Browse files
Files changed (2) hide show
  1. hanlde_form_submit.py +6 -2
  2. prompt_engeneering.py +53 -0
hanlde_form_submit.py CHANGED
@@ -1,6 +1,7 @@
1
  import streamlit as st
2
  from grouped_sampling import GroupedSamplingPipeLine
3
 
 
4
  from supported_models import get_supported_model_names
5
 
6
 
@@ -15,7 +16,7 @@ def create_pipeline(model_name: str, group_size) -> GroupedSamplingPipeLine:
15
  :return: A pipeline with the given model name and group size.
16
  """
17
  print(f"Starts downloading model: {model_name} from the internet.")
18
- pipeline = GroupedSamplingPipeLine(
19
  model_name=model_name,
20
  group_size=group_size,
21
  end_of_sentence_stop=True,
@@ -38,9 +39,12 @@ def generate_text(
38
  :param output_length: The size of the groups to use.
39
  :return: The generated text.
40
  """
 
41
  return pipeline(
42
- prompt_s=prompt,
43
  max_new_tokens=output_length,
 
 
44
  )["generated_text"]
45
 
46
 
 
1
  import streamlit as st
2
  from grouped_sampling import GroupedSamplingPipeLine
3
 
4
+ from prompt_engeneering import rewrite_prompt
5
  from supported_models import get_supported_model_names
6
 
7
 
 
16
  :return: A pipeline with the given model name and group size.
17
  """
18
  print(f"Starts downloading model: {model_name} from the internet.")
19
+ pipeline = GroupedSamplingPipeLine(
20
  model_name=model_name,
21
  group_size=group_size,
22
  end_of_sentence_stop=True,
 
39
  :param output_length: The size of the groups to use.
40
  :return: The generated text.
41
  """
42
+ better_prompt = rewrite_prompt(prompt, 5)
43
  return pipeline(
44
+ prompt_s=better_prompt,
45
  max_new_tokens=output_length,
46
+ return_text=True,
47
+ return_full_text=False,
48
  )["generated_text"]
49
 
50
 
prompt_engeneering.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from datetime import datetime
3
+ from typing import Generator
4
+
5
+ import requests
6
+
7
+
8
@dataclass
class SearchResult:
    """A single web search hit.

    The fields are declared as dataclass fields (rather than assigned in a
    hand-written ``__init__``) so that the generated ``__init__``,
    ``__repr__`` and ``__eq__`` actually take them into account.
    """

    # Title of the result as returned by the search service.
    title: str
    # Text snippet of the result.
    body: str
    # Link to the result page.
    url: str
14
+
15
+
16
def get_web_search_results(
    prompt: str,
    num_results: int,
    timeout: float = 10.0,
) -> Generator[SearchResult, None, None]:
    """Yield web search results for the given prompt.

    Sends the prompt to the ``ddg-webapp-aagd.vercel.app`` search endpoint
    and yields one ``SearchResult`` per entry of the JSON response.

    This is a generator function, so the HTTP request is only sent once the
    caller starts iterating.

    :param prompt: The search query. It is URL-encoded before being sent.
    :param num_results: Sent to the service as ``max_results``.
    :param timeout: Seconds to wait for the service before giving up.
    :raises ValueError: If the service answers with a non-200 status code.
    """
    # Let requests build the query string: it percent-encodes the prompt, so
    # characters such as '&', '#' or spaces cannot corrupt the query.
    response = requests.get(
        "https://ddg-webapp-aagd.vercel.app/search",
        params={"max_results": num_results, "q": prompt},
        timeout=timeout,
    )
    if response.status_code != 200:
        raise ValueError(f"Failed to get web search results for prompt: {prompt}")
    for result in response.json():
        yield SearchResult(
            title=result["title"],
            body=result["body"],
            url=result["href"],
        )
33
+
34
+
35
def format_search_result(search_result: Generator["SearchResult", None, None]) -> str:
    """Format search results as a numbered text block for the prompt.

    Each result becomes ``[i] <body>`` followed by ``URL: <url>`` and a blank
    line, where ``i`` counts from 0 in iteration order.

    :param search_result: The results to format; only ``body`` and ``url``
        of each item are read.
    :return: The concatenated entries, or an empty string if there are none.
    """
    # Join once instead of growing a string with += inside the loop.
    return "".join(
        f"[{i}] {result.body}\nURL: {result.url}\n\n"
        for i, result in enumerate(search_result)
    )
41
+
42
+
43
def rewrite_prompt(
    prompt: str,
    num_results: int,
) -> str:
    """Build an augmented prompt that embeds web search results.

    The returned text has four newline-separated sections: the formatted
    search results, the current date (day/month/year), fixed instructions
    for the model, and the original query.

    :param prompt: The user's original query.
    :param num_results: How many search results to request.
    :return: The augmented prompt.
    """
    hits = get_web_search_results(prompt, num_results)
    # Formatting consumes the generator, which is what triggers the search.
    results_section = f"Web search results: {format_search_result(hits)}"
    date_section = f"Current date: {datetime.now():%d/%m/%Y}"
    instructions_section = (
        "Instructions: Using the provided web search results, write a "
        "comprehensive reply to the given query. Make sure to cite results "
        "using [[number](URL)] notation after the reference. If the provided "
        "search results refer to multiple subjects with the same name, write "
        "separate answers for each subject."
    )
    query_section = "Query: {}".format(prompt)
    sections = (results_section, date_section, instructions_section, query_section)
    return "\n".join(sections)