IlyaGusev committed on
Commit
c8a296f
1 Parent(s): eaf0bb2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -16
app.py CHANGED
@@ -18,10 +18,7 @@ from langchain.document_loaders import (
18
  UnstructuredWordDocumentLoader,
19
  )
20
  from langchain.text_splitter import RecursiveCharacterTextSplitter
21
- from langchain.vectorstores import Chroma
22
- from langchain.embeddings import HuggingFaceEmbeddings
23
  from langchain.docstore.document import Document
24
- from chromadb.config import Settings
25
  from llama_cpp import Llama
26
 
27
 
@@ -66,9 +63,9 @@ def load_model(
66
  print("Model loaded!")
67
  return model
68
 
 
69
  MAX_NEW_TOKENS = 1500
70
- EMBEDDER_NAME = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
71
- EMBEDDER = HuggingFaceEmbeddings(model_name=EMBEDDER_NAME)
72
  MODEL = load_model()
73
 
74
 
@@ -121,15 +118,12 @@ def build_index(file_paths, db, chunk_size, chunk_overlap, file_warning):
121
  continue
122
  fixed_documents.append(doc)
123
  print("Documents after processing:", len(fixed_documents))
124
-
125
- db = Chroma.from_documents(
126
- fixed_documents,
127
- EMBEDDER,
128
- client_settings=Settings(
129
- anonymized_telemetry=False
130
- )
131
- )
132
  print("Embeddings calculated!")
 
133
  file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
134
  return db, file_warning
135
 
@@ -138,9 +132,11 @@ def retrieve(history, db, retrieved_docs, k_documents):
138
  retrieved_docs = ""
139
  if db:
140
  last_user_message = history[-1][0]
141
- retriever = db.as_retriever(search_kwargs={"k": k_documents})
142
- docs = retriever.get_relevant_documents(last_user_message)
143
- retrieved_docs = "\n\n".join([doc.page_content for doc in docs])
 
 
144
  return retrieved_docs
145
 
146
 
 
18
  UnstructuredWordDocumentLoader,
19
  )
20
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
21
  from langchain.docstore.document import Document
 
22
  from llama_cpp import Llama
23
 
24
 
 
63
  print("Model loaded!")
64
  return model
65
 
66
+
67
  MAX_NEW_TOKENS = 1500
68
+ EMBEDDER = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
 
69
  MODEL = load_model()
70
 
71
 
 
118
  continue
119
  fixed_documents.append(doc)
120
  print("Documents after processing:", len(fixed_documents))
121
+
122
+ texts = [doc.page_content for doc in fixed_documents]
123
+ embeddings = EMBEDDER.encode(texts, convert_to_tensor=True)
124
+ db = {"docs": texts, "embeddings": embeddings}
 
 
 
 
125
  print("Embeddings calculated!")
126
+
127
  file_warning = f"Загружено {len(fixed_documents)} фрагментов! Можно задавать вопросы."
128
  return db, file_warning
129
 
 
132
  retrieved_docs = ""
133
  if db:
134
  last_user_message = history[-1][0]
135
+ query_embedding = EMBEDDER.encode(last_user_message, convert_to_tensor=True)
136
+ scores = cos_sim(query_embedding, db["embeddings"])[0]
137
+ top_k_idx = torch.topk(scores, k=k_documents)[1]
138
+ top_k_documents = [db["docs"][idx] for idx in top_k_idx]
139
+ retrieved_docs = "\n\n".join(top_k_documents)
140
  return retrieved_docs
141
 
142