isayahc committed on
Commit
fdb6484
1 Parent(s): d7a243c

Did some refactoring and added documentation.

app_gui.py CHANGED
@@ -1,16 +1,27 @@
  # Import Gradio for UI, along with other necessary libraries
  import gradio as gr
  from rag_app.agents.react_agent import agent_executor
  # need to import the qa!

- # Function to add a new input to the chat history
  def add_text(history, text):
      # Append the new text to the history with a placeholder for the response
      history = history + [(text, None)]
      return history, ""

- # Function representing the bot's response mechanism
  def bot(history):
      # Obtain the response from the 'infer' function using the latest input
      response = infer(history[-1][0], history)
      #sources = [doc.metadata.get("source") for doc in response['source_documents']]
@@ -23,10 +34,13 @@ def bot(history):
      history[-1][1] = response['output']
      return history

- # Function to infer the response using the RAG model
  def infer(question, history):
      # Use the question and history to query the RAG model
-     #result = qa({"query": question, "history": history, "question": question})
      try:
          result = agent_executor.invoke(
              {
@@ -37,6 +51,8 @@ def infer(question, history):
          return result
      except Exception:
          raise gr.Error("Model is Overloaded, Please retry later!")

  # CSS styling for the Gradio interface
  css = """
 
  # Import Gradio for UI, along with other necessary libraries
  import gradio as gr
  from rag_app.agents.react_agent import agent_executor
+ from config import db
  # need to import the qa!
+ db.create_new_session()
+

  def add_text(history, text):
+     """Function to add a new input to the chat history.
+
+     Returns:
+         The updated chat history and an empty string.
+     """
+
      # Append the new text to the history with a placeholder for the response
      history = history + [(text, None)]
      return history, ""

+
  def bot(history):
+     """Function representing the bot's response mechanism
+
+     """
+
      # Obtain the response from the 'infer' function using the latest input
      response = infer(history[-1][0], history)
      #sources = [doc.metadata.get("source") for doc in response['source_documents']]

      history[-1][1] = response['output']
      return history

+
  def infer(question, history):
+     """Function to infer the response using the RAG model
+
+     """
+
      # Use the question and history to query the RAG model
      try:
          result = agent_executor.invoke(
              {

          return result
      except Exception:
          raise gr.Error("Model is Overloaded, Please retry later!")
+
+

  # CSS styling for the Gradio interface
  css = """
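A quick usage sketch of the refactored handlers (not part of this commit; it assumes agent_executor and the environment it needs are configured, and that history follows Gradio's list-of-pairs chat format):

    # Hypothetical walkthrough of the handlers defined above
    history = []
    history, cleared = add_text(history, "What is RAG?")   # history == [("What is RAG?", None)], cleared == ""
    history = bot(history)       # infer() runs agent_executor.invoke(...) and fills history[-1][1]
    print(history[-1][1])        # the agent's reply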
config.py ADDED
@@ -0,0 +1,13 @@
+ import os
+ from dotenv import load_dotenv
+ from rag_app.database.db_handler import DataBaseHandler
+
+ load_dotenv()
+
+ sqlite_file_name = os.getenv('SOURCES_CACHE')
+
+
+ db = DataBaseHandler()
+
+ db.create_all_tables()
+
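config.py now owns the shared handler: it loads the environment, builds one DataBaseHandler, and creates the tables at import time, so any module that does from config import db sees the same engine and session_id. A minimal sketch of that usage (illustration only, not in this commit):

    from config import db

    print(db.session_id, db.session_date_time)   # values set when the handler was constructed
    db.create_new_session()                      # rotate the session tag for a new app run

Note that the sqlite_file_name read here is not passed to the handler; DataBaseHandler falls back to SOURCES_CACHE in its own default argument.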
rag_app/database/__init__.py CHANGED
@@ -0,0 +1 @@
+ from rag_app.database.db_handler import DataBaseHandler
rag_app/database/db_handler.py CHANGED
@@ -3,182 +3,222 @@ from rag_app.database.schema import Sources
  from rag_app.utils.logger import get_console_logger
  import os
  from dotenv import load_dotenv
-
- load_dotenv()
-
- sqlite_file_name = os.getenv('SOURCES_CACHE')
-
- sqlite_url = f"sqlite:///{sqlite_file_name}"
- engine = create_engine(sqlite_url, echo=False)
-
- logger = get_console_logger("db_handler")
-
- SQLModel.metadata.create_all(engine)
-
-
- def read_one(hash_id: dict):
-     """
-     Read a single entry from the database by its hash_id.
-
-     Args:
-         hash_id (dict): Dictionary containing the hash_id to search for.
-
-     Returns:
-         Sources: The matching entry from the database, or None if no match is found.
-     """
-     with Session(engine) as session:
-         statement = select(Sources).where(Sources.hash_id == hash_id)
-         sources = session.exec(statement).first()
-         return sources
-
-
- def add_one(data: dict):
-     """
-     Add a single entry to the database.
-
-     Args:
-         data (dict): Dictionary containing the data for the new entry.
-
-     Returns:
-         Sources: The added entry, or None if the entry already exists.
-     """
-     with Session(engine) as session:
-         if session.exec(
-             select(Sources).where(Sources.hash_id == data.get("hash_id"))
-         ).first():
-             logger.warning(f"Item with hash_id {data.get('hash_id')} already exists")
-             return None # or raise an exception, or handle as needed
-         sources = Sources(**data)
-         session.add(sources)
-         session.commit()
-         session.refresh(sources)
-         logger.info(f"Item with hash_id {data.get('hash_id')} added to the database")
-         return sources
-
-
- def update_one(hash_id: dict, data: dict):
-     """
-     Update a single entry in the database by its hash_id.
-
-     Args:
-         hash_id (dict): Dictionary containing the hash_id to search for.
-         data (dict): Dictionary containing the updated data for the entry.
-
-     Returns:
-         Sources: The updated entry, or None if no match is found.
-     """
-     with Session(engine) as session:
-         # Check if the item with the given hash_id exists
-         sources = session.exec(
-             select(Sources).where(Sources.hash_id == hash_id)
-         ).first()
-         if not sources:
-             logger.warning(f"No item with hash_id {hash_id} found for update")
-             return None # or raise an exception, or handle as needed
-         for key, value in data.items():
-             setattr(sources, key, value)
-         session.commit()
-         logger.info(f"Item with hash_id {hash_id} updated in the database")
-         return sources
-
-
- def delete_one(id: int):
-     """
-     Delete a single entry from the database by its id.
-
-     Args:
-         id (int): The id of the entry to delete.
-
-     Returns:
-         None
-     """
-     with Session(engine) as session:
-         # Check if the item with the given hash_id exists
-         sources = session.exec(
-             select(Sources).where(Sources.hash_id == id)
-         ).first()
-         if not sources:
-             logger.warning(f"No item with hash_id {id} found for deletion")
-             return None # or raise an exception, or handle as needed
-         session.delete(sources)
-         session.commit()
-         logger.info(f"Item with hash_id {id} deleted from the database")
-
-
- def add_many(data: list):
-     """
-     Add multiple entries to the database.
-
-     Args:
-         data (list): List of dictionaries, each containing the data for a new entry.
-
-     Returns:
-         None
-     """
-     with Session(engine) as session:
-         for info in data:
-             # Reuse add_one function for each item
-             result = add_one(info)
-             if result is None:
-                 logger.warning(
-                     f"Item with hash_id {info.get('hash_id')} could not be added"
-                 )
-             else:
-                 logger.info(
-                     f"Item with hash_id {info.get('hash_id')} added to the database"
                  )
-         session.commit() # Commit at the end of the loop


- def delete_many(ids: list):
-     """
-     Delete multiple entries from the database by their ids.
-
-     Args:
-         ids (list): List of ids of the entries to delete.

-     Returns:
-         None
-     """
-     with Session(engine) as session:
-         for id in ids:
-             # Reuse delete_one function for each item
-             result = delete_one(id)
-             if result is None:
-                 logger.warning(f"No item with hash_id {id} found for deletion")
-             else:
-                 logger.info(f"Item with hash_id {id} deleted from the database")
-         session.commit() # Commit at the end of the loop
-
-
- def read_all(query: dict = None):
-     """
-     Read all entries from the database, optionally filtered by a query.
-
-     Args:
-         query (dict, optional): Dictionary containing the query parameters. Defaults to None.
-
-     Returns:
-         list: List of matching entries from the database.
-     """
-     with Session(engine) as session:
-         statement = select(Sources)
-         if query:
-             statement = statement.where(
-                 *[getattr(Sources, key) == value for key, value in query.items()]
-             )
-         sources = session.exec(statement).all()
-         return sources
-
-
- def delete_all():
-     """
-     Delete all entries from the database.
-
-     Returns:
-         None
-     """
-     with Session(engine) as session:
-         session.exec(Sources).delete()
-         session.commit()
-         logger.info("All items deleted from the database")
 
  from rag_app.utils.logger import get_console_logger
  import os
  from dotenv import load_dotenv
+ import uuid
+ from datetime import datetime
+
+
+ class DataBaseHandler():
+     """
+     A class for managing the database.
+
+     Attributes:
+         sqlite_file_name (str): The SQLite file name for the database.
+         logger (Logger): The logger for logging database operations.
+         engine (Engine): The SQLAlchemy engine for the database.
+
+     Methods:
+         create_all_tables: Create all tables in the database.
+         read_one: Read a single entry from the database by its hash_id.
+         add_one: Add a single entry to the database.
+         update_one: Update a single entry in the database by its hash_id.
+         delete_one: Delete a single entry from the database by its id.
+         add_many: Add multiple entries to the database.
+         delete_many: Delete multiple entries from the database by their ids.
+         read_all: Read all entries from the database, optionally filtered by a query.
+         delete_all: Delete all entries from the database.
+     """
+
+     def __init__(
+         self,
+         sqlite_file_name = os.getenv('SOURCES_CACHE'),
+         logger = get_console_logger("db_handler"),
+         # *args,
+         # **kwargs,
+     ):
+         self.sqlite_file_name = sqlite_file_name
+         self.logger = logger
+
+         sqlite_url = f"sqlite:///{self.sqlite_file_name}"
+         self.engine = create_engine(sqlite_url, echo=False)
+
+
+         self.session_id = str(uuid.uuid4())
+         self.session_date_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+     def create_all_tables(self) -> None:
+         SQLModel.metadata.create_all(self.engine)
+
+     def create_new_session(self) -> None:
+         """creates a new session_id and date time
+
+         """
+         self.session_id = str(uuid.uuid4())
+         self.session_date_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+
+     def read_one(self,hash_id: dict):
+         """
+         Read a single entry from the database by its hash_id.
+
+         Args:
+             hash_id (dict): Dictionary containing the hash_id to search for.
+
+         Returns:
+             Sources: The matching entry from the database, or None if no match is found.
+         """
+         with Session(self.engine) as session:
+             statement = select(Sources).where(Sources.hash_id == hash_id)
+             sources = session.exec(statement).first()
+             return sources
+
+
+     def add_one(self,data: dict):
+         """
+         Add a single entry to the database.
+
+         Args:
+             data (dict): Dictionary containing the data for the new entry.
+
+         Returns:
+             Sources: The added entry, or None if the entry already exists.
+         """
+         with Session(self.engine) as session:
+             if session.exec(
+                 select(Sources).where(Sources.hash_id == data.get("hash_id"))
+             ).first():
+                 self.logger.warning(f"Item with hash_id {data.get('hash_id')} already exists")
+                 return None # or raise an exception, or handle as needed
+             sources = Sources(**data)
+             session.add(sources)
+             session.commit()
+             session.refresh(sources)
+             self.logger.info(f"Item with hash_id {data.get('hash_id')} added to the database")
+             return sources
+
+
+     def update_one(self,hash_id: dict, data: dict):
+         """
+         Update a single entry in the database by its hash_id.
+
+         Args:
+             hash_id (dict): Dictionary containing the hash_id to search for.
+             data (dict): Dictionary containing the updated data for the entry.
+
+         Returns:
+             Sources: The updated entry, or None if no match is found.
+         """
+         with Session(self.engine) as session:
+             # Check if the item with the given hash_id exists
+             sources = session.exec(
+                 select(Sources).where(Sources.hash_id == hash_id)
+             ).first()
+             if not sources:
+                 self.logger.warning(f"No item with hash_id {hash_id} found for update")
+                 return None # or raise an exception, or handle as needed
+             for key, value in data.items():
+                 setattr(sources, key, value)
+             session.commit()
+             self.logger.info(f"Item with hash_id {hash_id} updated in the database")
+             return sources
+
+
+     def delete_one(self,id: int):
+         """
+         Delete a single entry from the database by its id.
+
+         Args:
+             id (int): The id of the entry to delete.
+
+         Returns:
+             None
+         """
+         with Session(self.engine) as session:
+             # Check if the item with the given hash_id exists
+             sources = session.exec(
+                 select(Sources).where(Sources.hash_id == id)
+             ).first()
+             if not sources:
+                 self.logger.warning(f"No item with hash_id {id} found for deletion")
+                 return None # or raise an exception, or handle as needed
+             session.delete(sources)
+             session.commit()
+             self.logger.info(f"Item with hash_id {id} deleted from the database")
+
+
+     def add_many(self,data: list):
+         """
+         Add multiple entries to the database.
+
+         Args:
+             data (list): List of dictionaries, each containing the data for a new entry.
+
+         Returns:
+             None
+         """
+         with Session(self.engine) as session:
+             for info in data:
+                 # Reuse add_one function for each item
+                 result = self.add_one(info)
+                 if result is None:
+                     self.logger.warning(
+                         f"Item with hash_id {info.get('hash_id')} could not be added"
+                     )
+                 else:
+                     self.logger.info(
+                         f"Item with hash_id {info.get('hash_id')} added to the database"
+                     )
+             session.commit() # Commit at the end of the loop
+
+
+     def delete_many(self,ids: list):
+         """
+         Delete multiple entries from the database by their ids.
+
+         Args:
+             ids (list): List of ids of the entries to delete.
+
+         Returns:
+             None
+         """
+         with Session(self.engine) as session:
+             for id in ids:
+                 # Reuse delete_one function for each item
+                 result = self.delete_one(id)
+                 if result is None:
+                     self.logger.warning(f"No item with hash_id {id} found for deletion")
+                 else:
+                     self.logger.info(f"Item with hash_id {id} deleted from the database")
+             session.commit() # Commit at the end of the loop
+
+
+     def read_all(self,query: dict = None):
+         """
+         Read all entries from the database, optionally filtered by a query.
+
+         Args:
+             query (dict, optional): Dictionary containing the query parameters. Defaults to None.
+
+         Returns:
+             list: List of matching entries from the database.
+         """
+         with Session(self.engine) as session:
+             statement = select(Sources)
+             if query:
+                 statement = statement.where(
+                     *[getattr(Sources, key) == value for key, value in query.items()]
                  )
+             sources = session.exec(statement).all()
+             return sources


+     def delete_all(self,):
+         """
+         Delete all entries from the database.

+         Returns:
+             None
+         """
+         with Session(self.engine) as session:
+             session.exec(Sources).delete()
+             session.commit()
+             self.logger.info("All items deleted from the database")
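A minimal usage sketch of the new class (not part of this commit). Only url, hash_id, session_id and session_date_time are taken from the code shown here; the file name and any other Sources fields are assumptions, and SOURCES_CACHE must point at a writable SQLite path if no name is passed:

    from rag_app.database.db_handler import DataBaseHandler

    handler = DataBaseHandler(sqlite_file_name="sources_cache.db")  # hypothetical file name
    handler.create_all_tables()

    handler.add_one({
        "url": "https://example.com",
        "hash_id": "abc123",                         # duplicate hash_ids are skipped and logged
        "session_id": handler.session_id,
        "session_date_time": handler.session_date_time,
    })
    rows = handler.read_all({"session_id": handler.session_id})
    print(rows)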
rag_app/database/schema.py CHANGED
@@ -5,7 +5,7 @@ import datetime
  class Sources(SQLModel, table=True):
      """
      Database schema for the Sources table.
-
      Attributes:
          id (Optional[int]): The primary key for the table.
          url (str): The URL of the source.
@@ -14,6 +14,8 @@ class Sources(SQLModel, table=True):
          created_at (float): Timestamp indicating when the entry was created.
          summary (str): A summary of the source content.
          embedded (bool): Flag indicating whether the source is embedded.
      """
      id: Optional[int] = Field(default=None, primary_key=True)
      url: str = Field()
@@ -22,5 +24,7 @@ class Sources(SQLModel, table=True):
      created_at: float = Field(default=datetime.datetime.now().timestamp())
      summary: str = Field(default="")
      embedded: bool = Field(default=False)

      __table_args__ = {"extend_existing": True}
 
  class Sources(SQLModel, table=True):
      """
      Database schema for the Sources table.
+
      Attributes:
          id (Optional[int]): The primary key for the table.
          url (str): The URL of the source.

          created_at (float): Timestamp indicating when the entry was created.
          summary (str): A summary of the source content.
          embedded (bool): Flag indicating whether the source is embedded.
+         session_id (str): A unique identifier for the session when the entry was added.
+         session_date_time (str): The timestamp when the session was created.
      """
      id: Optional[int] = Field(default=None, primary_key=True)
      url: str = Field()

      created_at: float = Field(default=datetime.datetime.now().timestamp())
      summary: str = Field(default="")
      embedded: bool = Field(default=False)
+     session_id: str = Field(default="")
+     session_date_time: str = Field(default="")

      __table_args__ = {"extend_existing": True}
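The two new columns let every cached source be traced back to the app run that stored it. A sketch of how a row might be tagged (illustration only; it assumes the shared handler from config.py supplies the session values and that hash_id is a field of Sources, which db_handler.py relies on):

    from rag_app.database.schema import Sources
    from config import db

    row = Sources(
        url="https://example.com",
        hash_id="abc123",                    # assumed field; any other required fields omitted
        session_id=db.session_id,
        session_date_time=db.session_date_time,
    )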
rag_app/structured_tools/agent_tools.py CHANGED
@@ -20,12 +20,12 @@ def web_research(query: str) -> List[dict]:
  def ask_user(query: str) -> str:
      """Ask the user directly if you are not sure what they mean or if you need a decision."""

-     result = HumanInputRun.invoke(query)
      return result

  @tool
  def get_email(query: str) -> str:
      """Ask the user for their email address once you think you have answered their request, so that we can send them more information afterwards."""

-     result = HumanInputRun.invoke(query)
      return result
 
  def ask_user(query: str) -> str:
      """Ask the user directly if you are not sure what they mean or if you need a decision."""

+     result = HumanInputRun().invoke(query)
      return result

  @tool
  def get_email(query: str) -> str:
      """Ask the user for their email address once you think you have answered their request, so that we can send them more information afterwards."""

+     result = HumanInputRun().invoke(query)
      return result
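The change here is that HumanInputRun is a tool class, so it has to be instantiated before invoke() can be called; invoking it on the class itself fails with a TypeError. A small illustration (the import path is an assumption, since the module's own import is outside this hunk):

    from langchain_community.tools import HumanInputRun

    human = HumanInputRun()     # could also be created once at module level and reused by both tools
    answer = human.invoke("Which option do you prefer?")   # prompts on the console and returns the reply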
rag_app/structured_tools/structured_tools.py CHANGED
@@ -1,25 +1,17 @@
- from langchain.tools import BaseTool, StructuredTool, tool
- from langchain_community.tools import WikipediaQueryRun
- from langchain_community.utilities import WikipediaAPIWrapper
- #from langchain.tools import Tool
  from langchain_google_community import GoogleSearchAPIWrapper
  from langchain_community.embeddings.sentence_transformer import (
      SentenceTransformerEmbeddings,
  )
  from langchain_community.vectorstores import Chroma
- import ast
-
- import chromadb
-
  from rag_app.utils.utils import (
      parse_list_to_dicts, format_search_results
  )
- from rag_app.database.db_handler import (
-     add_many
- )
-
  import os
- # from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text

  persist_directory = os.getenv('VECTOR_DATABASE_LOCATION')
  embedding_model = os.getenv("EMBEDDING_MODEL")
@@ -49,6 +41,7 @@ def memory_search(query:str) -> str:
      retriever = vector_db.as_retriever()
      docs = retriever.invoke(query)

      return docs.__str__()

  @tool
@@ -91,8 +84,8 @@ def google_search(query: str) -> str:
      if len(search_results)>1:
          cleaner_sources =format_search_results(search_results)
          parsed_csources = parse_list_to_dicts(cleaner_sources)
-         add_many(parsed_csources)
      else:
          cleaner_sources = search_results

-     return cleaner_sources.__str__()
 
+ from langchain.tools import tool
  from langchain_google_community import GoogleSearchAPIWrapper
  from langchain_community.embeddings.sentence_transformer import (
      SentenceTransformerEmbeddings,
  )
  from langchain_community.vectorstores import Chroma
  from rag_app.utils.utils import (
      parse_list_to_dicts, format_search_results
  )
+ import chromadb
  import os
+ from config import db
+
+

  persist_directory = os.getenv('VECTOR_DATABASE_LOCATION')
  embedding_model = os.getenv("EMBEDDING_MODEL")

      retriever = vector_db.as_retriever()
      docs = retriever.invoke(query)

+
      return docs.__str__()

  @tool

      if len(search_results)>1:
          cleaner_sources =format_search_results(search_results)
          parsed_csources = parse_list_to_dicts(cleaner_sources)
+         db.add_many(parsed_csources)
      else:
          cleaner_sources = search_results

+     return cleaner_sources.__str__()
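With the module-level db_handler functions gone, google_search now persists parsed results through the shared handler imported from config. A rough invocation sketch (not part of this commit; it assumes the Google Search credentials and the vector store environment variables that the module reads are set):

    from rag_app.structured_tools.structured_tools import google_search

    snippets = google_search.invoke("open source vector databases")
    # internally: format_search_results -> parse_list_to_dicts -> db.add_many(...),
    # so each parsed source is cached once per hash_id in the SQLite sources table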