File size: 2,858 Bytes
d029425
d102e03
 
27e2360
d102e03
27e2360
 
d102e03
8d0a0d3
 
 
 
 
 
 
d102e03
27e2360
 
d029425
 
6bcf2e3
d029425
27e2360
 
 
 
d029425
 
6bcf2e3
d029425
27e2360
 
 
 
 
 
 
 
 
d102e03
 
27e2360
 
 
 
d102e03
 
 
27e2360
d102e03
 
27e2360
 
 
 
 
 
 
 
d102e03
 
27e2360
 
d102e03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfa084c
27e2360
dfa084c
 
27e2360
d102e03
 
8d0a0d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Generator, Dict, List

from googleapiclient.discovery import build
from streamlit import secrets

# Fixed instruction text inserted into every rewritten prompt; it tells the
# model how to use the search results and how to cite them.
INSTRUCTIONS = (
    "Instructions: Using the provided web search results, "
    "write a comprehensive reply to the given query. "
    "Make sure to cite results using [[number](URL)] notation "
    "after the reference. "
    "If the provided search results refer to multiple subjects "
    "with the same name, "
    "write separate answers for each subject."
)


def get_google_api_key():
    """Return the Google search API key.

    The key is read from the ``google_search_api_key`` entry of streamlit's
    secrets. If it cannot be read from there, the environment variable with
    the same name is used instead.

    Raises:
        KeyError: If neither streamlit's secrets nor the environment
            provides ``google_search_api_key``.
    """
    try:
        return secrets["google_search_api_key"]
    except (FileNotFoundError, IsADirectoryError, KeyError):
        # Either no usable secrets file was found, or the file exists but has
        # no such entry; in both cases fall back to the environment.
        return os.environ["google_search_api_key"]


def get_google_cse_id():
    """Return the Google Custom Search Engine (CSE) ID.

    The ID is read from the ``google_cse_id`` entry of streamlit's secrets.
    If it cannot be read from there, the environment variable with the same
    name is used instead.

    Raises:
        KeyError: If neither streamlit's secrets nor the environment
            provides ``google_cse_id``.
    """
    try:
        return secrets["google_cse_id"]
    except (FileNotFoundError, IsADirectoryError, KeyError):
        # Either no usable secrets file was found, or the file exists but has
        # no such entry; in both cases fall back to the environment.
        return os.environ["google_cse_id"]


def google_search(search_term, **kwargs) -> list:
    """Run a Google Custom Search query and return the raw result items.

    Args:
        search_term: Query string, sent as the ``q`` parameter.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            ``cse().list(...)`` request (for example ``num``).

    Returns:
        The ``items`` list of the API response, or an empty list when the
        response contains no ``items`` entry.
    """
    service = build("customsearch", "v1", developerKey=get_google_api_key())
    response = service.cse().list(
        q=search_term,
        cx=get_google_cse_id(),
        **kwargs,
    ).execute()
    # A response without matches may carry no "items" key at all; report that
    # as "no results" instead of failing with a KeyError.
    return response.get("items", [])


@dataclass
class SearchResult:
    """One web search hit: its title, its text body and its URL."""

    # Restrict instances to exactly these attributes (no per-instance
    # __dict__ is created).
    __slots__ = [
        "title",
        "body",
        "url",
    ]

    title: str  # title of the result
    body: str  # text shown for the result
    url: str  # link to the result


def get_web_search_results(
        query: str,
        num_results: int,
) -> Generator[SearchResult, None, None]:
    """Yield web search results for *query* using the Google search API.

    The search runs lazily, when the returned generator is first iterated.

    Args:
        query: Search term passed to ``google_search``.
        num_results: Number of results requested from the API (as ``num``);
            the returned items are also cut to this length.

    Yields:
        One ``SearchResult`` per returned item. ``body`` is the item's
        snippet with a trailing ``"\\xa0..."`` marker removed, or an empty
        string when the item has no snippet.
    """
    truncation_marker = "\xa0..."
    raw_results: List[Dict[str, str]] = google_search(
        search_term=query,
        num=num_results,
    )[:num_results]
    for raw in raw_results:
        # Not every item is guaranteed to carry a snippet; treat a missing
        # one as empty text rather than aborting the whole search.
        snippet = raw.get("snippet", "")
        if snippet.endswith(truncation_marker):
            # Work on a local copy so the raw API dict is left untouched.
            snippet = snippet[:-len(truncation_marker)]
        yield SearchResult(
            title=raw["title"],
            body=snippet,
            url=raw["link"],
        )


def format_search_result(search_result: Generator[SearchResult, None, None]) -> str:
    """Format search results as numbered entries to be added to the prompt.

    Args:
        search_result: The results to format; each one must provide ``body``
            and ``url`` attributes.

    Returns:
        One ``"[i] <body>\\nURL: <url>\\n\\n"`` entry per result, numbered
        from 0 in iteration order and concatenated; an empty string when
        there are no results.
    """
    # Build all entries and join once instead of growing a string with "+=".
    return "".join(
        f"[{index}] {result.body}\nURL: {result.url}\n\n"
        for index, result in enumerate(search_result)
    )


def rewrite_prompt(
        prompt: str,
        num_results: int = 5,
) -> str:
    """Rewrite the prompt by adding web search results to it.

    Args:
        prompt: The user's query; used both as the search term and as the
            final ``Query:`` line.
        num_results: How many search results to request and include.
            Defaults to 5.

    Returns:
        A single string made of four parts joined by newlines: the formatted
        search results, the current local date as ``dd/mm/YYYY``, the fixed
        ``INSTRUCTIONS`` text, and the original query.
    """
    search_results = get_web_search_results(
        query=prompt,
        num_results=num_results,
    )
    sections = [
        "Web search results:\n" + format_search_result(search_results),
        "Current date: " + datetime.now().strftime("%d/%m/%Y"),
        INSTRUCTIONS,
        f"Query: {prompt}",
    ]
    return "\n".join(sections)