Spaces:
Running
Running
improve answer length with mistral, avoid changing model after selection
Browse files- document_qa_engine.py +2 -2
- streamlit_app.py +15 -15
document_qa_engine.py
CHANGED
@@ -200,8 +200,8 @@ class DocumentQAEngine:
|
|
200 |
|
201 |
return texts, metadatas, ids
|
202 |
|
203 |
-
def create_memory_embeddings(self, pdf_path, doc_id=None):
|
204 |
-
texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=
|
205 |
if doc_id:
|
206 |
hash = doc_id
|
207 |
else:
|
|
|
200 |
|
201 |
return texts, metadatas, ids
|
202 |
|
203 |
+
def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
|
204 |
+
texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
|
205 |
if doc_id:
|
206 |
hash = doc_id
|
207 |
else:
|
streamlit_app.py
CHANGED
@@ -45,21 +45,18 @@ def new_file():
|
|
45 |
|
46 |
|
47 |
@st.cache_resource
|
48 |
-
def init_qa(
|
49 |
if model == 'chatgpt-3.5-turbo':
|
50 |
chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
|
51 |
temperature=0,
|
52 |
return_pl_id=True,
|
53 |
-
pl_tags=["streamlit", "chatgpt"]
|
54 |
-
|
55 |
-
embeddings = OpenAIEmbeddings(openai_api_key=api_key)
|
56 |
elif model == 'mistral-7b-instruct-v0.1':
|
57 |
chat = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
|
58 |
-
model_kwargs={"temperature": 0.01}
|
59 |
-
api_key=api_key)
|
60 |
embeddings = HuggingFaceEmbeddings(
|
61 |
-
model_name="all-MiniLM-L6-v2"
|
62 |
-
api_key=api_key)
|
63 |
|
64 |
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
|
65 |
|
@@ -85,6 +82,7 @@ def play_old_messages():
|
|
85 |
else:
|
86 |
st.write(message['content'])
|
87 |
|
|
|
88 |
|
89 |
model = st.sidebar.radio("Model", ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),
|
90 |
index=1,
|
@@ -92,20 +90,22 @@ model = st.sidebar.radio("Model", ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.
|
|
92 |
"ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
|
93 |
"Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
|
94 |
],
|
95 |
-
help="Select the model you want to use."
|
|
|
96 |
|
97 |
-
is_api_key_provided = False
|
98 |
if not st.session_state['api_key']:
|
99 |
if model == 'mistral-7b-instruct-v0.1':
|
100 |
-
api_key = st.sidebar.text_input('Huggingface API Key')
|
101 |
if api_key:
|
102 |
st.session_state['api_key'] = is_api_key_provided = True
|
103 |
-
|
|
|
104 |
elif model == 'chatgpt-3.5-turbo':
|
105 |
-
api_key = st.sidebar.text_input('OpenAI API Key')
|
106 |
if api_key:
|
107 |
st.session_state['api_key'] = is_api_key_provided = True
|
108 |
-
|
|
|
109 |
else:
|
110 |
is_api_key_provided = st.session_state['api_key']
|
111 |
|
@@ -158,7 +158,7 @@ if uploaded_file and not st.session_state.loaded_embeddings:
|
|
158 |
tmp_file = NamedTemporaryFile()
|
159 |
tmp_file.write(bytearray(binary))
|
160 |
# hash = get_file_hash(tmp_file.name)[:10]
|
161 |
-
st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name)
|
162 |
st.session_state['loaded_embeddings'] = True
|
163 |
|
164 |
# timestamp = datetime.utcnow()
|
|
|
45 |
|
46 |
|
47 |
@st.cache_resource
|
48 |
+
def init_qa(model):
|
49 |
if model == 'chatgpt-3.5-turbo':
|
50 |
chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
|
51 |
temperature=0,
|
52 |
return_pl_id=True,
|
53 |
+
pl_tags=["streamlit", "chatgpt"])
|
54 |
+
embeddings = OpenAIEmbeddings()
|
|
|
55 |
elif model == 'mistral-7b-instruct-v0.1':
|
56 |
chat = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
|
57 |
+
model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
|
|
|
58 |
embeddings = HuggingFaceEmbeddings(
|
59 |
+
model_name="all-MiniLM-L6-v2")
|
|
|
60 |
|
61 |
return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
|
62 |
|
|
|
82 |
else:
|
83 |
st.write(message['content'])
|
84 |
|
85 |
+
is_api_key_provided = st.session_state['api_key']
|
86 |
|
87 |
model = st.sidebar.radio("Model", ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),
|
88 |
index=1,
|
|
|
90 |
"ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
|
91 |
"Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
|
92 |
],
|
93 |
+
help="Select the model you want to use.",
|
94 |
+
disabled=is_api_key_provided)
|
95 |
|
|
|
96 |
if not st.session_state['api_key']:
|
97 |
if model == 'mistral-7b-instruct-v0.1':
|
98 |
+
api_key = st.sidebar.text_input('Huggingface API Key') if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
|
99 |
if api_key:
|
100 |
st.session_state['api_key'] = is_api_key_provided = True
|
101 |
+
os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
|
102 |
+
st.session_state['rqa'] = init_qa(model)
|
103 |
elif model == 'chatgpt-3.5-turbo':
|
104 |
+
api_key = st.sidebar.text_input('OpenAI API Key') if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
|
105 |
if api_key:
|
106 |
st.session_state['api_key'] = is_api_key_provided = True
|
107 |
+
os.environ['OPENAI_API_KEY'] = api_key
|
108 |
+
st.session_state['rqa'] = init_qa(model)
|
109 |
else:
|
110 |
is_api_key_provided = st.session_state['api_key']
|
111 |
|
|
|
158 |
tmp_file = NamedTemporaryFile()
|
159 |
tmp_file.write(bytearray(binary))
|
160 |
# hash = get_file_hash(tmp_file.name)[:10]
|
161 |
+
st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name, chunk_size=250, perc_overlap=0.1)
|
162 |
st.session_state['loaded_embeddings'] = True
|
163 |
|
164 |
# timestamp = datetime.utcnow()
|