Update app.py
Browse files
app.py
CHANGED
@@ -18,10 +18,7 @@ from langchain.document_loaders import (
|
|
18 |
UnstructuredWordDocumentLoader,
|
19 |
)
|
20 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
21 |
-
from langchain.vectorstores import Chroma
|
22 |
-
from langchain.embeddings import HuggingFaceEmbeddings
|
23 |
from langchain.docstore.document import Document
|
24 |
-
from chromadb.config import Settings
|
25 |
from llama_cpp import Llama
|
26 |
|
27 |
|
@@ -66,9 +63,9 @@ def load_model(
|
|
66 |
print("Model loaded!")
|
67 |
return model
|
68 |
|
|
|
69 |
MAX_NEW_TOKENS = 1500
|
70 |
-
|
71 |
-
EMBEDDER = HuggingFaceEmbeddings(model_name=EMBEDDER_NAME)
|
72 |
MODEL = load_model()
|
73 |
|
74 |
|
@@ -121,15 +118,12 @@ def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
|
|
121 |
continue
|
122 |
fixed_documents.append(doc)
|
123 |
print("Documents after processing:", len(fixed_documents))
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
client_settings=Settings(
|
129 |
-
anonymized_telemetry=False
|
130 |
-
)
|
131 |
-
)
|
132 |
print("Embeddings calculated!")
|
|
|
133 |
file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
|
134 |
return db, file_warning
|
135 |
|
@@ -138,9 +132,11 @@ def retrieve(history, db, retrieved_docs, k_documents):
|
|
138 |
retrieved_docs = ""
|
139 |
if db:
|
140 |
last_user_message = history[-1][0]
|
141 |
-
|
142 |
-
|
143 |
-
|
|
|
|
|
144 |
return retrieved_docs
|
145 |
|
146 |
|
|
|
18 |
UnstructuredWordDocumentLoader,
|
19 |
)
|
20 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
|
|
21 |
from langchain.docstore.document import Document
|
|
|
22 |
from llama_cpp import Llama
|
23 |
|
24 |
|
|
|
63 |
print("Model loaded!")
|
64 |
return model
|
65 |
|
66 |
+
|
67 |
MAX_NEW_TOKENS = 1500
|
68 |
+
EMBEDDER = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
|
|
|
69 |
MODEL = load_model()
|
70 |
|
71 |
|
|
|
118 |
continue
|
119 |
fixed_documents.append(doc)
|
120 |
print("Documents after processing:", len(fixed_documents))
|
121 |
+
|
122 |
+
texts = [doc.page_content for doc in fixed_documents]
|
123 |
+
embeddings = EMBEDDER.encode(texts, convert_to_tensor=True)
|
124 |
+
db = {"docs": texts, "embeddings": embeddings}
|
|
|
|
|
|
|
|
|
125 |
print("Embeddings calculated!")
|
126 |
+
|
127 |
file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
|
128 |
return db, file_warning
|
129 |
|
|
|
132 |
retrieved_docs = ""
|
133 |
if db:
|
134 |
last_user_message = history[-1][0]
|
135 |
+
query_embedding = EMBEDDER.encode(last_user_message, convert_to_tensor=True)
|
136 |
+
scores = cos_sim(query_embedding, db["embeddings"])[0]
|
137 |
+
# Clamp k to the number of stored chunks: torch.topk raises a RuntimeError when k is larger than the dimension being ranked (e.g. a short upload with a large k_documents).
top_k_idx = torch.topk(scores, k=min(k_documents, len(db["docs"])))[1]
|
138 |
+
top_k_documents = [db["docs"][idx] for idx in top_k_idx]
|
139 |
+
retrieved_docs = "\n\n".join(top_k_documents)
|
140 |
return retrieved_docs
|
141 |
|
142 |
|