Sergidev committed
Commit 8e9f89e
1 Parent(s): 487e8b7

Update modules/pmbl.py

Files changed (1):
  1. modules/pmbl.py +64 -65
modules/pmbl.py CHANGED
@@ -1,56 +1,54 @@
-import sqlite3
+import aiosqlite
 from datetime import datetime
 from ctransformers import AutoModelForCausalLM
-from concurrent.futures import ThreadPoolExecutor
+import asyncio
 
 class PMBL:
     def __init__(self, model_path, gpu_layers=50):
         self.model_path = model_path
         self.gpu_layers = gpu_layers
-        self.init_db()
-        self.executor = ThreadPoolExecutor(max_workers=6)
-
-    def init_db(self):
-        conn = sqlite3.connect('chat_history.db')
-        c = conn.cursor()
-        c.execute('''CREATE TABLE IF NOT EXISTS chats
-                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
-                      timestamp TEXT,
-                      prompt TEXT,
-                      response TEXT,
-                      topic TEXT)''')
-        conn.commit()
-        conn.close()
-
-    def get_chat_history(self, mode="full", user_message=""):
-        conn = sqlite3.connect('chat_history.db')
-        c = conn.cursor()
-
-        if mode == "full":
-            c.execute("SELECT prompt, response FROM chats ORDER BY id")
-            history = []
-            for row in c.fetchall():
-                history.append({"role": "user", "content": row[0]})
-                history.append({"role": "PMB", "content": row[1]})
-        else:
-            c.execute("SELECT id, prompt, response FROM chats WHERE topic != 'Untitled'")
-            chats = c.fetchall()
-            relevant_chat_id = self.find_relevant_chat(chats, user_message)
-
-            if relevant_chat_id:
-                c.execute("SELECT prompt, response FROM chats WHERE id = ?", (relevant_chat_id,))
-                row = c.fetchone()
-                history = [
-                    {"role": "user", "content": row[0]},
-                    {"role": "PMB", "content": row[1]}
-                ]
+        self.db_name = 'chat_history.db'
+        self.init_db_lock = asyncio.Lock()
+
+    async def init_db(self):
+        async with self.init_db_lock:
+            async with aiosqlite.connect(self.db_name) as db:
+                await db.execute('''CREATE TABLE IF NOT EXISTS chats
+                             (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                              timestamp TEXT,
+                              prompt TEXT,
+                              response TEXT,
+                              topic TEXT)''')
+                await db.commit()
+
+    async def get_chat_history(self, mode="full", user_message=""):
+        await self.init_db()
+        async with aiosqlite.connect(self.db_name) as db:
+            if mode == "full":
+                async with db.execute("SELECT prompt, response FROM chats ORDER BY id") as cursor:
+                    rows = await cursor.fetchall()
+                    history = []
+                    for row in rows:
+                        history.append({"role": "user", "content": row[0]})
+                        history.append({"role": "PMB", "content": row[1]})
             else:
-                history = []
+                async with db.execute("SELECT id, prompt, response FROM chats WHERE topic != 'Untitled'") as cursor:
+                    chats = await cursor.fetchall()
+                relevant_chat_id = await self.find_relevant_chat(chats, user_message)
+
+                if relevant_chat_id:
+                    async with db.execute("SELECT prompt, response FROM chats WHERE id = ?", (relevant_chat_id,)) as cursor:
+                        row = await cursor.fetchone()
+                    history = [
+                        {"role": "user", "content": row[0]},
+                        {"role": "PMB", "content": row[1]}
+                    ]
+                else:
+                    history = []
 
-        conn.close()
         return history
 
-    def find_relevant_chat(self, chats, user_message):
+    async def find_relevant_chat(self, chats, user_message):
         max_score = 0
         relevant_chat_id = None
 
@@ -75,15 +73,14 @@ class PMBL:
 
         return score
 
-    def save_chat_history(self, prompt, response):
-        conn = sqlite3.connect('chat_history.db')
-        c = conn.cursor()
-        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        c.execute("INSERT INTO chats (timestamp, prompt, response, topic) VALUES (?, ?, ?, 'Untitled')", (timestamp, prompt, response))
-        conn.commit()
-        conn.close()
+    async def save_chat_history(self, prompt, response):
+        await self.init_db()
+        async with aiosqlite.connect(self.db_name) as db:
+            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+            await db.execute("INSERT INTO chats (timestamp, prompt, response, topic) VALUES (?, ?, ?, 'Untitled')", (timestamp, prompt, response))
+            await db.commit()
 
-    def generate_response(self, prompt, history, mode):
+    async def generate_response(self, prompt, history, mode):
         history.append({"role": "user", "content": prompt})
 
         formatted_history = ""
@@ -117,7 +114,7 @@ class PMBL:
             response_text += chunk
             yield chunk
 
-        self.save_chat_history(prompt, response_text)
+        await self.save_chat_history(prompt, response_text)
 
     def calculate_context(self, system_prompt, formatted_history):
         system_prompt_tokens = len(system_prompt) // 4
@@ -131,21 +128,19 @@ class PMBL:
         else:
             return context_ceiling
 
-    def sleep_mode(self):
-        conn = sqlite3.connect('chat_history.db')
-        c = conn.cursor()
-        c.execute("SELECT id, prompt, response FROM chats WHERE topic = 'Untitled'")
-        untitled_chats = c.fetchall()
+    async def sleep_mode(self):
+        await self.init_db()
+        async with aiosqlite.connect(self.db_name) as db:
+            async with db.execute("SELECT id, prompt, response FROM chats WHERE topic = 'Untitled'") as cursor:
+                untitled_chats = await cursor.fetchall()
 
-        for chat in untitled_chats:
-            chat_id, prompt, response = chat
-            topic = self.generate_topic(prompt, response)
-            c.execute("UPDATE chats SET topic = ? WHERE id = ?", (topic, chat_id))
-            conn.commit()
-
-        conn.close()
+            for chat in untitled_chats:
+                chat_id, prompt, response = chat
+                topic = await self.generate_topic(prompt, response)
+                await db.execute("UPDATE chats SET topic = ? WHERE id = ?", (topic, chat_id))
+                await db.commit()
 
-    def generate_topic(self, prompt, response):
+    async def generate_topic(self, prompt, response):
         llm = AutoModelForCausalLM.from_pretrained(
             self.model_path,
             model_type="llama",
@@ -162,4 +157,8 @@ class PMBL:
             stop=["\n"]
         )
 
-        return topic.strip()
+        return topic.strip()
+
+    async def close(self):
+        # Implement any cleanup operations here if needed
+        pass
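
For reviewers, a minimal usage sketch of the now-async PMBL interface, assuming it is driven from an asyncio event loop. The model path and the wrapper script below are illustrative assumptions, not part of this commit; the method names and signatures come from the diff above.

import asyncio
from modules.pmbl import PMBL

async def main():
    # Hypothetical model path, for illustration only
    pmbl = PMBL("path/to/model.gguf", gpu_layers=50)

    # Each DB method lazily awaits init_db(), so chat_history.db is created on first use
    history = await pmbl.get_chat_history(mode="full")

    # generate_response is an async generator: chunks stream out as they are produced,
    # and the finished response is stored via save_chat_history inside the generator
    async for chunk in pmbl.generate_response("Hello, PMB!", history, mode="full"):
        print(chunk, end="", flush=True)

    # Tag any 'Untitled' chats with generated topics, then release resources
    await pmbl.sleep_mode()
    await pmbl.close()

asyncio.run(main())

Note that every database-touching method opens its own short-lived aiosqlite connection and re-runs the schema check in init_db (serialized by an asyncio.Lock), so no separate setup call is required before the first query.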