|
"""Util that calls Wikipedia. references: https://github.com/hwchase17/langchain/blob/9b615022e2b6a3591347ad77a3e21aad6cf24c49/docs/extras/modules/agents/tools/integrations/wikipedia.ipynb#L36""" |
|
import logging |
|
from typing import Any, Dict, List, Optional |
|
|
|
from pydantic import BaseModel, root_validator |
|
|
|
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Search queries are cut to at most this many characters before being passed
# to the Wikipedia client's search call.
WIKIPEDIA_MAX_QUERY_LENGTH: int = 300
|
|
|
|
|
class WikipediaAPIWrapper(BaseModel):
    """Wrapper around the Wikipedia API.

    To use, you should have the ``wikipedia`` python package installed.
    This wrapper uses the Wikipedia API to conduct searches and fetch page
    summaries. By default, it returns the summaries of the top-k results.
    The combined output of :meth:`run` is limited to
    ``doc_content_chars_max`` characters.

    :param top_k_results: The maximum number of search results to summarize.
    :type top_k_results: int
    :param lang: The language code passed to ``wikipedia.set_lang``.
    :type lang: str
    :param doc_content_chars_max: The maximum number of characters returned
        by :meth:`run`.
    :type doc_content_chars_max: int
    :param wiki_client: The Wikipedia API client (the imported ``wikipedia``
        module). It is populated automatically by ``validate_environment``.
    :type wiki_client: Any
    """

    wiki_client: Any
    top_k_results: int = 5
    lang: str = "en"
    doc_content_chars_max: int = 4000

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Import the ``wikipedia`` package and attach it as ``wiki_client``.

        The validator also applies the requested language via
        ``wikipedia.set_lang``. Because it is called on the imported module
        itself, the setting presumably applies to every user of that module
        in this process.

        :param values: The raw constructor input to validate.
        :type values: Dict
        :return: The input values with ``wiki_client`` set.
        :rtype: Dict
        :raises ImportError: If the ``wikipedia`` package is not installed.
        """
        try:
            import wikipedia
        except ImportError as err:
            raise ImportError(
                "Could not import wikipedia python package. "
                "Please install it with `pip install wikipedia`."
            ) from err

        # This is a ``pre`` validator, so field defaults have not been applied
        # yet. ``lang`` is only present here if the caller passed it
        # explicitly, so fall back to the field's default ("en").
        wikipedia.set_lang(values.get("lang", "en"))
        values["wiki_client"] = wikipedia
        return values

    def run(self, query: str) -> str:
        """Run a Wikipedia search and return the summaries of the top results.

        Pages that cannot be fetched (missing or ambiguous) are skipped.

        :param query: The query to search for.
        :type query: str
        :return: The page summaries separated by blank lines, truncated to
            ``doc_content_chars_max`` characters. If no page could be
            summarized, a fixed "no result" message is returned instead.
        :rtype: str
        """
        summaries = []
        for page_title in self.search_page_titles(query):
            wiki_page = self._fetch_page(page_title)
            if wiki_page is None:
                continue
            summary = self._formatted_page_summary(page_title, wiki_page)
            if summary:
                summaries.append(summary)
        if not summaries:
            return "No good Wikipedia Search Result was found"
        return "\n\n".join(summaries)[: self.doc_content_chars_max]

    def _fetch_page(self, page: str) -> Optional[Any]:
        """Fetch a Wikipedia page object by its exact title.

        The page object (not its text) is returned because callers such as
        :meth:`run` read attributes like ``summary`` from it.

        :param page: The title of the page to fetch.
        :type page: str
        :return: The page object returned by ``wiki_client.page``, or ``None``
            if the client raises ``PageError`` or ``DisambiguationError``.
        :rtype: Optional[Any]
        """
        try:
            # auto_suggest=False: look up the title exactly as the search
            # returned it, rather than letting the client substitute a
            # suggestion.
            return self.wiki_client.page(title=page, auto_suggest=False)
        except (
            self.wiki_client.exceptions.PageError,
            self.wiki_client.exceptions.DisambiguationError,
        ):
            return None

    def search_page_titles(self, query: str) -> List[str]:
        """Run a Wikipedia search and return the titles of the top results.

        The query is truncated to ``WIKIPEDIA_MAX_QUERY_LENGTH`` characters
        before searching. At most ``top_k_results`` titles are returned.

        :param query: The query to search for.
        :type query: str
        :return: The page titles.
        :rtype: List[str]
        """
        return self.wiki_client.search(query[:WIKIPEDIA_MAX_QUERY_LENGTH])[
            : self.top_k_results
        ]

    @staticmethod
    def _formatted_page_summary(page_title: str, wiki_page: Any) -> str:
        """Format a page's title and summary as a single string.

        :param page_title: The page title.
        :type page_title: str
        :param wiki_page: The Wikipedia page object; its ``summary`` attribute
            is read.
        :type wiki_page: Any
        :return: The string ``"Page: <title>\\nSummary: <summary>"``.
        :rtype: str
        """
        return f"Page: {page_title}\nSummary: {wiki_page.summary}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|