|
import copy
import re
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast

import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.utils.math import cosine_similarity
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.embeddings import Embeddings
|
|
def calculate_cosine_distances(sentence_embeddings: List[List[float]]) -> np.ndarray:
    """Calculate cosine distances between adjacent sentences.

    Args:
        sentence_embeddings: List of sentence embeddings to calculate distances for.

    Returns:
        Cosine distance between each pair of adjacent sentences.
    """
    # Build the full pairwise distance matrix, then slice out its
    # superdiagonal, i.e. the entries (i, i + 1) holding the distance
    # between each sentence and the one that follows it.
    n = len(sentence_embeddings)
    return (1 - cosine_similarity(sentence_embeddings, sentence_embeddings)).flatten()[1 :: n + 1]
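# Worked example (illustrative, not executed at import time): for n = 3, the
# flattened 3 x 3 distance matrix has its adjacent-pair entries at flat
# indices 1 and 5, i.e. every (n + 1)-th element starting at index 1. Three
# identical vectors therefore yield one zero distance per adjacent pair:
#
#     calculate_cosine_distances([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
#     # -> array([0., 0.])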
|
|
|
|
|
BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
    "percentile": 95,
    "standard_deviation": 3,
    "interquartile": 1.5,
}
|
|
|
|
|
class BoundedSemanticChunker(BaseDocumentTransformer):
    """Semantic chunker with a hard upper bound on chunk size.

    First splits the text using semantic chunking according to the specified
    'breakpoint_threshold_amount', then uses a RecursiveCharacterTextSplitter
    to split any chunks that are larger than 'max_chunk_size'.

    Adapted from langchain_experimental.text_splitter.SemanticChunker.
    """

    def __init__(
        self,
        embeddings: Embeddings,
        buffer_size: int = 1,
        add_start_index: bool = False,
        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
        breakpoint_threshold_amount: Optional[float] = None,
        number_of_chunks: Optional[int] = None,
        max_chunk_size: int = 500,
    ):
        self._add_start_index = add_start_index
        self.embeddings = embeddings
        self.buffer_size = buffer_size
        self.breakpoint_threshold_type = breakpoint_threshold_type
        self.number_of_chunks = number_of_chunks
        if breakpoint_threshold_amount is None:
            self.breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[
                breakpoint_threshold_type
            ]
        else:
            self.breakpoint_threshold_amount = breakpoint_threshold_amount
        self.max_chunk_size = max_chunk_size

        # Sentences are assumed to end with '.', '?' or '!' followed by whitespace.
        self.sentence_split_regex = re.compile(r"(?<=[.?!])\s+")

        assert (
            self.breakpoint_threshold_type == "percentile"
        ), "only breakpoint_threshold_type 'percentile' is currently supported"
        assert self.buffer_size == 1, "combining sentences is not supported yet"
|
|
|
    def _calculate_sentence_distances(self, sentences: List[str]) -> np.ndarray:
        """Embed the sentences and return distances between adjacent ones."""
        embeddings = self.embeddings.embed_documents(sentences)
        return calculate_cosine_distances(embeddings)
|
|
|
    def _calculate_breakpoint_threshold(
        self,
        distances: np.ndarray,
        alt_breakpoint_threshold_amount: Optional[float] = None,
    ) -> float:
        if alt_breakpoint_threshold_amount is None:
            breakpoint_threshold_amount = self.breakpoint_threshold_amount
        else:
            breakpoint_threshold_amount = alt_breakpoint_threshold_amount
        if self.breakpoint_threshold_type == "percentile":
            return cast(
                float,
                np.percentile(distances, breakpoint_threshold_amount),
            )
        elif self.breakpoint_threshold_type == "standard_deviation":
            return cast(
                float,
                np.mean(distances) + breakpoint_threshold_amount * np.std(distances),
            )
        elif self.breakpoint_threshold_type == "interquartile":
            q1, q3 = np.percentile(distances, [25, 75])
            iqr = q3 - q1
            return cast(float, np.mean(distances) + breakpoint_threshold_amount * iqr)
        else:
            raise ValueError(
                f"Got unexpected `breakpoint_threshold_type`: "
                f"{self.breakpoint_threshold_type}"
            )
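    # Worked example for the percentile strategy (illustrative): with
    # distances [0.1, 0.2, 0.9] and the default amount of 95, the threshold
    # is np.percentile([0.1, 0.2, 0.9], 95) ≈ 0.83, so only the 0.9 gap
    # exceeds it and becomes a breakpoint.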
|
|
|
    def _threshold_from_clusters(self, distances: List[float]) -> float:
        """Calculate the threshold based on the number of chunks.

        Inverse of the percentile method: the requested number of chunks is
        mapped linearly onto a percentile of the distance distribution.
        """
        if self.number_of_chunks is None:
            raise ValueError(
                "This should never be called if `number_of_chunks` is None."
            )
        # Interpolate between (len(distances), 0) and (1, 100): asking for a
        # single chunk maps to the 100th percentile (no breakpoints), asking
        # for the maximum number of chunks maps to the 0th percentile.
        x1, y1 = len(distances), 0.0
        x2, y2 = 1.0, 100.0

        # With only one distance, x1 == x2 and the interpolation below would
        # divide by zero; fall back to the 100th percentile (a single chunk).
        if x1 == x2:
            return cast(float, np.percentile(distances, y2))

        # Clamp the requested number of chunks into the feasible range.
        x = max(min(self.number_of_chunks, x1), x2)

        y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
        y = min(max(y, 0), 100)

        return cast(float, np.percentile(distances, y))
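    # Worked example (illustrative): with 10 distances and number_of_chunks=5,
    # y = 0 + (100 - 0) / (1 - 10) * (5 - 10) ≈ 55.6, so the threshold is the
    # 55.6th percentile of the distances.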
|
|
|
    def split_text(
        self,
        text: str,
    ) -> List[str]:
        sentences = self.sentence_split_regex.split(text)

        # A single sentence cannot be split semantically.
        if len(sentences) == 1:
            return sentences

        # Sentences that do not fit into a bounded chunk are collected here
        # and handed to a RecursiveCharacterTextSplitter at the end.
        bad_sentences: List[str] = []

        distances = self._calculate_sentence_distances(sentences)

        if self.number_of_chunks is not None:
            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
        else:
            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
                distances
            )

        # A semantic breakpoint is any adjacent-sentence distance above the threshold.
        indices_above_thresh = [
            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
        ]

        chunks = []
        start_index = 0

        def add_group(group: List[str]) -> None:
            """Emit a sentence group as one chunk if it fits within
            max_chunk_size; otherwise emit the largest prefix that fits and
            defer the remaining sentences to the recursive splitter."""
            combined_text = " ".join(group)
            if len(combined_text) <= self.max_chunk_size:
                chunks.append(combined_text)
                return
            # The length of the first k sentences joined by single spaces is
            # sum(lengths) + (k - 1) = cumsum(lengths + 1) - 1.
            sent_lengths = np.array([len(s) for s in group])
            good_indices = np.flatnonzero(
                np.cumsum(sent_lengths + 1) - 1 <= self.max_chunk_size
            )
            if good_indices.size > 0:
                chunks.append(" ".join(group[i] for i in good_indices))
                group = group[good_indices[-1] + 1 :]
            bad_sentences.extend(group)

        # Each breakpoint closes the group that ends at its index.
        for index in indices_above_thresh:
            add_group(sentences[start_index : index + 1])
            start_index = index + 1

        # Whatever follows the last breakpoint forms the final group.
        if start_index < len(sentences):
            add_group(sentences[start_index:])

        if len(bad_sentences) > 0:
            recursive_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.max_chunk_size,
                chunk_overlap=10,
                separators=["\n\n", "\n", ".", ", ", " ", ""],
            )
            chunks.extend(recursive_splitter.split_text(" ".join(bad_sentences)))
        return chunks
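    # End-to-end sketch of the bounding behaviour (illustrative): with
    # max_chunk_size=20, a semantic group like
    #     ["Short one.", "Another.", "A very very long sentence."]
    # is emitted as "Short one. Another." (19 characters, fits) while the
    # long tail is deferred to the RecursiveCharacterTextSplitter pass.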
|
|
|
    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    # Locate each chunk in the source text, searching forward
                    # from the previous match.
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents
|
|
|
    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)
|
|
|
    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))
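# Minimal usage sketch (illustrative, with assumptions): DeterministicFakeEmbedding
# from langchain_core stands in for a real embedding model here; swap in any
# Embeddings implementation to get meaningful semantic splits.
if __name__ == "__main__":
    from langchain_core.embeddings import DeterministicFakeEmbedding

    chunker = BoundedSemanticChunker(
        embeddings=DeterministicFakeEmbedding(size=256),
        breakpoint_threshold_type="percentile",
        breakpoint_threshold_amount=95,
        max_chunk_size=500,
    )
    docs = chunker.create_documents(["First sentence. Second sentence. Third one!"])
    for doc in docs:
        print(repr(doc.page_content))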