isayahc committed on
Commit
3b85ad0
1 Parent(s): dde4362

making classes and test for vectorstore handling

Browse files
pytest.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [pytest]
2
+ pythonpath = .
rag_app/vector_store_handler/__init__.py ADDED
File without changes
rag_app/vector_store_handler/vectorstores.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from langchain.vectorstores import Chroma, FAISS
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.text_splitter import CharacterTextSplitter
5
+ from langchain.document_loaders import TextLoader
6
+
7
class BaseVectorStore(ABC):
    """
    Abstract base class for vector stores.

    Holds the embedding model, the optional persistence directory and the
    backend store handle, and implements the backend-independent steps
    (loading/splitting a text file and running a similarity search).
    Subclasses provide creation, loading and saving for a concrete backend.
    """

    def __init__(self, embedding_model, persist_directory=None):
        """
        Initialize the BaseVectorStore.

        Args:
            embedding_model: The embedding model to use for vectorizing text.
            persist_directory (str, optional): Directory to persist the vector store.
        """
        self.persist_directory = persist_directory
        self.embeddings = embedding_model
        # No backend yet; subclasses are expected to assign it in
        # create_vectorstore() or load_existing_vectorstore().
        self.vectorstore = None

    def load_and_process_documents(self, file_path, chunk_size=1000, chunk_overlap=0):
        """
        Load a text file and split it into chunks.

        Args:
            file_path (str): Path to the file to load.
            chunk_size (int): Size of text chunks for processing.
            chunk_overlap (int): Overlap between chunks.

        Returns:
            list: The chunks produced by ``CharacterTextSplitter.split_documents``.
        """
        loader = TextLoader(file_path)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        return text_splitter.split_documents(documents)

    @abstractmethod
    def create_vectorstore(self, texts):
        """
        Create a new vector store from the given documents.

        Args:
            texts (list): Documents to vectorize and store, e.g. the output of
                ``load_and_process_documents``.
        """
        pass

    @abstractmethod
    def load_existing_vectorstore(self):
        """
        Load an existing vector store from the persist directory.
        """
        pass

    def similarity_search(self, query, **kwargs):
        """
        Perform a similarity search on the vector store.

        Args:
            query (str): The query text to search for.
            **kwargs: Optional extra arguments forwarded unchanged to the
                backend's ``similarity_search``. When omitted, the backend
                is called with the query only.

        Returns:
            list: Search results as returned by the backend.

        Raises:
            ValueError: If the vector store is not initialized.
        """
        # Compare against None explicitly: an initialized backend object that
        # happens to evaluate as falsy must not be treated as "missing".
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized. Call create_vectorstore or load_existing_vectorstore first.")
        return self.vectorstore.similarity_search(query, **kwargs)

    @abstractmethod
    def save(self):
        """
        Save the current state of the vector store.
        """
        pass
84
+
85
class ChromaVectorStore(BaseVectorStore):
    """
    Implementation of BaseVectorStore using Chroma as the backend.
    """

    def create_vectorstore(self, texts):
        """
        Create a new Chroma vector store from the given documents.

        Args:
            texts (list): Documents to vectorize and store; they are handed
                to ``Chroma.from_documents`` together with the embedding
                model and the persist directory.
        """
        self.vectorstore = Chroma.from_documents(
            texts,
            self.embeddings,
            persist_directory=self.persist_directory,
        )

    def load_existing_vectorstore(self):
        """
        Load an existing Chroma vector store from the persist directory.

        Raises:
            ValueError: If persist_directory is not set.
        """
        # Fail fast before touching the backend when there is nowhere to load from.
        if not self.persist_directory:
            raise ValueError("Persist directory is required for loading Chroma.")
        self.vectorstore = Chroma(
            persist_directory=self.persist_directory,
            embedding_function=self.embeddings,
        )

    def save(self):
        """
        Save the current state of the Chroma vector store.

        Raises:
            ValueError: If the vector store is not initialized.
        """
        # Explicit None check so only a genuinely missing store is rejected.
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized. Nothing to save.")
        # NOTE(review): persist() availability depends on the installed
        # LangChain/Chroma versions - confirm against the pinned dependencies.
        self.vectorstore.persist()
128
+
129
class FAISSVectorStore(BaseVectorStore):
    """
    Implementation of BaseVectorStore using FAISS as the backend.
    """

    def create_vectorstore(self, texts):
        """
        Create a new FAISS vector store from the given documents.

        Args:
            texts (list): Documents to vectorize and store; they are handed
                to ``FAISS.from_documents`` together with the embedding model.
        """
        self.vectorstore = FAISS.from_documents(texts, self.embeddings)

    def load_existing_vectorstore(self, **load_kwargs):
        """
        Load an existing FAISS vector store from the persist directory.

        Args:
            **load_kwargs: Optional extra arguments forwarded unchanged to
                ``FAISS.load_local``. When omitted, the call is made with the
                persist directory and the embedding model only.

        Raises:
            ValueError: If persist_directory is not set.
        """
        if not self.persist_directory:
            raise ValueError("Persist directory is required for loading FAISS.")
        # NOTE(review): FAISS.load_local presumably unpickles the stored
        # docstore, so only load directories you trust. Recent LangChain
        # releases appear to require allow_dangerous_deserialization=True,
        # which can be supplied through load_kwargs - confirm against the
        # installed version.
        self.vectorstore = FAISS.load_local(
            self.persist_directory,
            self.embeddings,
            **load_kwargs,
        )

    def save(self):
        """
        Save the current state of the FAISS vector store.

        Raises:
            ValueError: If the vector store is not initialized, or if
                persist_directory is not set.
        """
        # Explicit None check so only a genuinely missing store is rejected.
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized. Nothing to save.")
        # Mirror load_existing_vectorstore: without a target directory the
        # backend would be handed None and fail with an unrelated error.
        if not self.persist_directory:
            raise ValueError("Persist directory is required for saving FAISS.")
        self.vectorstore.save_local(self.persist_directory)
165
+
166
+ # Usage example:
167
def main():
    """
    Example usage of the vector store classes.

    Builds, queries, saves and reloads a Chroma store and then a FAISS
    store. The file path and queries are placeholders to be replaced.
    """

    def show_top_result(label, results):
        # A search may legitimately return nothing; indexing [0] blindly
        # would raise IndexError in that case.
        if results:
            print(label, results[0].page_content)
        else:
            print(label, "<no results>")

    # Create an embedding model
    embedding_model = OpenAIEmbeddings()

    # Using Chroma
    chroma_store = ChromaVectorStore(embedding_model, persist_directory="./chroma_store")
    texts = chroma_store.load_and_process_documents("path/to/your/file.txt")
    chroma_store.create_vectorstore(texts)
    results = chroma_store.similarity_search("Your query here")
    show_top_result("Chroma results:", results)
    chroma_store.save()

    # Load existing Chroma store
    existing_chroma = ChromaVectorStore(embedding_model, persist_directory="./chroma_store")
    existing_chroma.load_existing_vectorstore()
    results = existing_chroma.similarity_search("Another query")
    show_top_result("Existing Chroma results:", results)

    # Using FAISS
    faiss_store = FAISSVectorStore(embedding_model, persist_directory="./faiss_store")
    texts = faiss_store.load_and_process_documents("path/to/your/file.txt")
    faiss_store.create_vectorstore(texts)
    results = faiss_store.similarity_search("Your query here")
    show_top_result("FAISS results:", results)
    faiss_store.save()

    # Load existing FAISS store
    existing_faiss = FAISSVectorStore(embedding_model, persist_directory="./faiss_store")
    existing_faiss.load_existing_vectorstore()
    results = existing_faiss.similarity_search("Another query")
    show_top_result("Existing FAISS results:", results)


# Run the example only when this module is executed as a script.
if __name__ == "__main__":
    main()
tests/vector_store_handler/test_vectorstores.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ from unittest.mock import MagicMock, patch
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.schema import Document
5
+
6
+ # Update the import to reflect your project structure
7
+ from rag_app.vector_store_handler.vectorstores import BaseVectorStore, ChromaVectorStore, FAISSVectorStore
8
+
9
class _ConcreteVectorStore(BaseVectorStore):
    """Minimal concrete subclass so the abstract base can be instantiated in tests."""

    def create_vectorstore(self, texts):
        self.vectorstore = texts

    def load_existing_vectorstore(self):
        pass

    def save(self):
        pass


class TestBaseVectorStore(unittest.TestCase):
    """Tests for the backend-independent behaviour implemented on BaseVectorStore."""

    def setUp(self):
        self.embedding_model = MagicMock(spec=OpenAIEmbeddings)
        # BaseVectorStore declares abstract methods, so instantiating it
        # directly raises TypeError; exercise it through a trivial subclass.
        self.base_store = _ConcreteVectorStore(self.embedding_model, "test_dir")

    def test_abstract_base_cannot_be_instantiated(self):
        with self.assertRaises(TypeError):
            BaseVectorStore(self.embedding_model, "test_dir")

    def test_init(self):
        self.assertEqual(self.base_store.persist_directory, "test_dir")
        self.assertEqual(self.base_store.embeddings, self.embedding_model)
        self.assertIsNone(self.base_store.vectorstore)

    @patch('rag_app.vector_store_handler.vectorstores.TextLoader')
    @patch('rag_app.vector_store_handler.vectorstores.CharacterTextSplitter')
    def test_load_and_process_documents(self, mock_splitter, mock_loader):
        mock_loader.return_value.load.return_value = ["doc1", "doc2"]
        mock_splitter.return_value.split_documents.return_value = ["split1", "split2"]

        result = self.base_store.load_and_process_documents("test.txt")

        mock_loader.assert_called_once_with("test.txt")
        mock_splitter.assert_called_once_with(chunk_size=1000, chunk_overlap=0)
        # The loaded documents must be what gets split.
        mock_splitter.return_value.split_documents.assert_called_once_with(["doc1", "doc2"])
        self.assertEqual(result, ["split1", "split2"])

    def test_similarity_search_not_initialized(self):
        with self.assertRaises(ValueError):
            self.base_store.similarity_search("query")

    def test_similarity_search_delegates_to_backend(self):
        backend = MagicMock()
        self.base_store.vectorstore = backend

        result = self.base_store.similarity_search("query")

        backend.similarity_search.assert_called_once_with("query")
        self.assertIs(result, backend.similarity_search.return_value)
34
+
35
class TestChromaVectorStore(unittest.TestCase):
    """Tests for the Chroma-backed vector store wrapper."""

    def setUp(self):
        self.embedding_model = MagicMock(spec=OpenAIEmbeddings)
        self.chroma_store = ChromaVectorStore(self.embedding_model, "test_dir")

    @patch('rag_app.vector_store_handler.vectorstores.Chroma')
    def test_create_vectorstore(self, mock_chroma):
        texts = [Document(page_content="test")]
        self.chroma_store.create_vectorstore(texts)
        mock_chroma.from_documents.assert_called_once_with(
            texts,
            self.embedding_model,
            persist_directory="test_dir"
        )
        # The created backend must be kept on the wrapper for later searches.
        self.assertIs(self.chroma_store.vectorstore, mock_chroma.from_documents.return_value)

    @patch('rag_app.vector_store_handler.vectorstores.Chroma')
    def test_load_existing_vectorstore(self, mock_chroma):
        self.chroma_store.load_existing_vectorstore()
        mock_chroma.assert_called_once_with(
            persist_directory="test_dir",
            embedding_function=self.embedding_model
        )
        self.assertIs(self.chroma_store.vectorstore, mock_chroma.return_value)

    def test_load_existing_vectorstore_requires_persist_directory(self):
        store = ChromaVectorStore(self.embedding_model)
        with self.assertRaises(ValueError):
            store.load_existing_vectorstore()

    def test_save(self):
        self.chroma_store.vectorstore = MagicMock()
        self.chroma_store.save()
        self.chroma_store.vectorstore.persist.assert_called_once()

    def test_save_not_initialized(self):
        with self.assertRaises(ValueError):
            self.chroma_store.save()
62
+
63
class TestFAISSVectorStore(unittest.TestCase):
    """Tests for the FAISS-backed vector store wrapper."""

    def setUp(self):
        self.embedding_model = MagicMock(spec=OpenAIEmbeddings)
        self.faiss_store = FAISSVectorStore(self.embedding_model, "test_dir")

    @patch('rag_app.vector_store_handler.vectorstores.FAISS')
    def test_create_vectorstore(self, mock_faiss):
        texts = [Document(page_content="test")]
        self.faiss_store.create_vectorstore(texts)
        mock_faiss.from_documents.assert_called_once_with(texts, self.embedding_model)
        # The created backend must be kept on the wrapper for later searches.
        self.assertIs(self.faiss_store.vectorstore, mock_faiss.from_documents.return_value)

    @patch('rag_app.vector_store_handler.vectorstores.FAISS')
    def test_load_existing_vectorstore(self, mock_faiss):
        self.faiss_store.load_existing_vectorstore()
        mock_faiss.load_local.assert_called_once_with("test_dir", self.embedding_model)
        self.assertIs(self.faiss_store.vectorstore, mock_faiss.load_local.return_value)

    def test_load_existing_vectorstore_requires_persist_directory(self):
        store = FAISSVectorStore(self.embedding_model)
        with self.assertRaises(ValueError):
            store.load_existing_vectorstore()

    def test_save(self):
        # save() only talks to the already-created backend instance, so the
        # FAISS class itself does not need to be patched here.
        self.faiss_store.vectorstore = MagicMock()
        self.faiss_store.save()
        self.faiss_store.vectorstore.save_local.assert_called_once_with("test_dir")

    def test_save_not_initialized(self):
        with self.assertRaises(ValueError):
            self.faiss_store.save()


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()