Spaces:
Running
Running
temp before HF pull
Browse files- .devcontainer/devcontainer.json +29 -0
- .gitattributes +1 -0
- .gitignore +179 -0
- .streamlit/config.toml +11 -0
- .vscode/launch.json +13 -0
- README.md +28 -0
- app.py +755 -0
- backend.py +185 -0
- class_templates.py +46 -0
- finetune_backend.py +55 -0
- helpers.py +112 -0
- llama_test.ipynb +80 -0
- openai_interface.py +100 -0
- preprocessing.py +123 -0
- prompt_templates.py +63 -0
- prompt_templates_luis.py +63 -0
- readme2.md +43 -0
- requirements.txt +50 -0
- reranker.py +89 -0
- retrieval_evaluation.py +332 -0
- unitesting_utils.py +34 -0
- utilities/install_kernel.sh +4 -0
- weaviate_interface.py +434 -0
.devcontainer/devcontainer.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "Python 3",
|
3 |
+
"image": "mcr.microsoft.com/devcontainers/python:3.10",
|
4 |
+
|
5 |
+
// Features to add to the dev container. More info: https://containers.dev/features.
|
6 |
+
//"features": {}
|
7 |
+
// Configure tool-specific properties.
|
8 |
+
"customizations": {
|
9 |
+
// Configure properties specific to VS Code.
|
10 |
+
"vscode": {
|
11 |
+
"settings": {"terminal.integrated.shell.linux": "/bin/bash"},
|
12 |
+
"extensions": [
|
13 |
+
"ms-toolsai.jupyter"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
},
|
17 |
+
"forwardPorts": [8501, 8888],
|
18 |
+
"portsAttributes": {
|
19 |
+
"8501": {
|
20 |
+
"label": "Streamlit App",
|
21 |
+
"onAutoForward": "openBrowser"
|
22 |
+
},
|
23 |
+
"8888": {
|
24 |
+
"label": "Jupyter Notebook",
|
25 |
+
"onAutoForward": "openBrowser"
|
26 |
+
}
|
27 |
+
},
|
28 |
+
"postCreateCommand": "pip install -r requirements.txt"
|
29 |
+
}
|
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/impact_theory_data.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Large Files
|
2 |
+
# models/
|
3 |
+
eval_results/
|
4 |
+
# models/all-mpnet*
|
5 |
+
# models/finetuned-all-MiniLM*
|
6 |
+
# models/finetuned-WhereIsAI-UAE*
|
7 |
+
models/*
|
8 |
+
# !models/finetuned-all-mpnet-base-v2-300
|
9 |
+
|
10 |
+
data/*.parquet
|
11 |
+
.DS_Store
|
12 |
+
secrets.toml
|
13 |
+
TODO.md
|
14 |
+
|
15 |
+
assets/*
|
16 |
+
!assets/it_tom_bilyeu.png
|
17 |
+
|
18 |
+
# Byte-compiled / optimized / DLL files
|
19 |
+
__pycache__/
|
20 |
+
*.py[cod]
|
21 |
+
*$py.class
|
22 |
+
*copy*
|
23 |
+
# C extensions
|
24 |
+
*.so
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
# Distribution / packaging
|
29 |
+
.Python
|
30 |
+
build/
|
31 |
+
develop-eggs/
|
32 |
+
dist/
|
33 |
+
downloads/
|
34 |
+
eggs/
|
35 |
+
.eggs/
|
36 |
+
lib/
|
37 |
+
lib64/
|
38 |
+
parts/
|
39 |
+
sdist/
|
40 |
+
var/
|
41 |
+
wheels/
|
42 |
+
share/python-wheels/
|
43 |
+
*.egg-info/
|
44 |
+
.installed.cfg
|
45 |
+
*.egg
|
46 |
+
MANIFEST
|
47 |
+
|
48 |
+
# PyInstaller
|
49 |
+
# Usually these files are written by a python script from a template
|
50 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
51 |
+
*.manifest
|
52 |
+
*.spec
|
53 |
+
|
54 |
+
# Installer logs
|
55 |
+
pip-log.txt
|
56 |
+
pip-delete-this-directory.txt
|
57 |
+
|
58 |
+
# Unit test / coverage reports
|
59 |
+
htmlcov/
|
60 |
+
.tox/
|
61 |
+
.nox/
|
62 |
+
.coverage
|
63 |
+
.coverage.*
|
64 |
+
.cache
|
65 |
+
nosetests.xml
|
66 |
+
coverage.xml
|
67 |
+
*.cover
|
68 |
+
*.py,cover
|
69 |
+
.hypothesis/
|
70 |
+
.pytest_cache/
|
71 |
+
cover/
|
72 |
+
|
73 |
+
# Translations
|
74 |
+
*.mo
|
75 |
+
*.pot
|
76 |
+
|
77 |
+
# Django stuff:
|
78 |
+
*.log
|
79 |
+
local_settings.py
|
80 |
+
db.sqlite3
|
81 |
+
db.sqlite3-journal
|
82 |
+
|
83 |
+
# Flask stuff:
|
84 |
+
instance/
|
85 |
+
.webassets-cache
|
86 |
+
|
87 |
+
# Scrapy stuff:
|
88 |
+
.scrapy
|
89 |
+
|
90 |
+
# Sphinx documentation
|
91 |
+
docs/_build/
|
92 |
+
|
93 |
+
# PyBuilder
|
94 |
+
.pybuilder/
|
95 |
+
target/
|
96 |
+
|
97 |
+
# Jupyter Notebook
|
98 |
+
.ipynb_checkpoints
|
99 |
+
|
100 |
+
# IPython
|
101 |
+
profile_default/
|
102 |
+
ipython_config.py
|
103 |
+
|
104 |
+
# pyenv
|
105 |
+
# For a library or package, you might want to ignore these files since the code is
|
106 |
+
# intended to run in multiple environments; otherwise, check them in:
|
107 |
+
# .python-version
|
108 |
+
|
109 |
+
# pipenv
|
110 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
111 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
112 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
113 |
+
# install all needed dependencies.
|
114 |
+
#Pipfile.lock
|
115 |
+
|
116 |
+
# poetry
|
117 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
118 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
119 |
+
# commonly ignored for libraries.
|
120 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
121 |
+
#poetry.lock
|
122 |
+
|
123 |
+
# pdm
|
124 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
125 |
+
#pdm.lock
|
126 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
127 |
+
# in version control.
|
128 |
+
# https://pdm.fming.dev/#use-with-ide
|
129 |
+
.pdm.toml
|
130 |
+
|
131 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
132 |
+
__pypackages__/
|
133 |
+
|
134 |
+
# Celery stuff
|
135 |
+
celerybeat-schedule
|
136 |
+
celerybeat.pid
|
137 |
+
|
138 |
+
# SageMath parsed files
|
139 |
+
*.sage.py
|
140 |
+
|
141 |
+
# Environments
|
142 |
+
.env
|
143 |
+
.venv
|
144 |
+
env/
|
145 |
+
venv/
|
146 |
+
ENV/
|
147 |
+
env.bak/
|
148 |
+
venv.bak/
|
149 |
+
|
150 |
+
# Spyder project settings
|
151 |
+
.spyderproject
|
152 |
+
.spyproject
|
153 |
+
|
154 |
+
# Rope project settings
|
155 |
+
.ropeproject
|
156 |
+
|
157 |
+
# mkdocs documentation
|
158 |
+
/site
|
159 |
+
|
160 |
+
# mypy
|
161 |
+
.mypy_cache/
|
162 |
+
.dmypy.json
|
163 |
+
dmypy.json
|
164 |
+
|
165 |
+
# Pyre type checker
|
166 |
+
.pyre/
|
167 |
+
|
168 |
+
# pytype static type analyzer
|
169 |
+
.pytype/
|
170 |
+
|
171 |
+
# Cython debug symbols
|
172 |
+
cython_debug/
|
173 |
+
|
174 |
+
# PyCharm
|
175 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
176 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
177 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
178 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
179 |
+
#.idea/
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[theme]
|
2 |
+
base="dark"
|
3 |
+
primaryColor="purple" # border of textboxes !!??
|
4 |
+
#primaryColor="#2d59b3"
|
5 |
+
|
6 |
+
backgroundColor="#000000"
|
7 |
+
secondaryBackgroundColor= "#0e404d" # should be identical to blue in banner # "#2d59b3" light blue
|
8 |
+
textColor="#FFFFFF"
|
9 |
+
font="sans serif"
|
10 |
+
|
11 |
+
|
.vscode/launch.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"version": "0.2.0",
|
3 |
+
"configurations": [
|
4 |
+
{
|
5 |
+
"name": "Python: Current File",
|
6 |
+
"type": "python",
|
7 |
+
"request": "launch",
|
8 |
+
"program": "${file}",
|
9 |
+
"console": "integratedTerminal",
|
10 |
+
"justMyCode": false
|
11 |
+
}
|
12 |
+
]
|
13 |
+
}
|
README.md
CHANGED
@@ -10,3 +10,31 @@ pinned: false
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
13 |
+
---
|
14 |
+
|
15 |
+
See the app @ [jpb-vectorsearch.streamlit.app](https://jpb-vectorsearch.streamlit.app/)
|
16 |
+
|
17 |
+
Beware, sometimes, the online app crashes... especially with the metrics.
|
18 |
+
|
19 |
+
<p align="left">
|
20 |
+
<img src="assets/screenshot_frontpage_with_finetune.png" width=800/>
|
21 |
+
</p>
|
22 |
+
|
23 |
+
<!-- <p align="center">
|
24 |
+
<img src="assets/screenshot_frontpage_online.png"/>
|
25 |
+
</p> -->
|
26 |
+
|
27 |
+
## Activity on Modal backend during finetuning
|
28 |
+
|
29 |
+
<p align="left">
|
30 |
+
<img src="assets/modal_finetuning1.png" width=800/>
|
31 |
+
</p>
|
32 |
+
|
33 |
+
<p align="left">
|
34 |
+
<img src="assets/modal_finetuning2.png" width=800/>
|
35 |
+
</p>
|
36 |
+
|
37 |
+
<p align="left">
|
38 |
+
<img src="assets/modal_finetuning_activity.png" width=800/>
|
39 |
+
</p>
|
40 |
+
|
app.py
ADDED
@@ -0,0 +1,755 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
from tiktoken import get_encoding, encoding_for_model
|
3 |
+
from weaviate_interface import WeaviateClient, WhereFilter
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from prompt_templates import question_answering_prompt_series, question_answering_system
|
6 |
+
from openai_interface import GPT_Turbo
|
7 |
+
from app_features import (convert_seconds, generate_prompt_series, search_result,
|
8 |
+
validate_token_threshold, load_content_cache, load_data,
|
9 |
+
expand_content)
|
10 |
+
from retrieval_evaluation import execute_evaluation, calc_hit_rate_scores
|
11 |
+
from llama_index.finetuning import EmbeddingQAFinetuneDataset
|
12 |
+
from weaviate_interface import WeaviateClient
|
13 |
+
from openai import BadRequestError
|
14 |
+
from reranker import ReRanker
|
15 |
+
from loguru import logger
|
16 |
+
import streamlit as st
|
17 |
+
from streamlit_option_menu import option_menu
|
18 |
+
import hydralit_components as hc
|
19 |
+
import sys
|
20 |
+
import json
|
21 |
+
import os, time, requests, re
|
22 |
+
from datetime import timedelta
|
23 |
+
import pathlib
|
24 |
+
import gdown
|
25 |
+
import tempfile
|
26 |
+
import base64
|
27 |
+
import shutil
|
28 |
+
|
29 |
+
def get_base64_of_bin_file(bin_file):
|
30 |
+
with open(bin_file, 'rb') as file:
|
31 |
+
data = file.read()
|
32 |
+
return base64.b64encode(data).decode()
|
33 |
+
|
34 |
+
from dotenv import load_dotenv, find_dotenv
|
35 |
+
load_dotenv(find_dotenv('env'), override=True)
|
36 |
+
|
37 |
+
# I use a key that I increment each time I want to change a text_input
|
38 |
+
if 'key' not in st.session_state:
|
39 |
+
st.session_state.key = 0
|
40 |
+
# key = st.session_state['key']
|
41 |
+
|
42 |
+
if not pathlib.Path('models').exists():
|
43 |
+
os.mkdir('models')
|
44 |
+
|
45 |
+
# I should cache these things but no time left
|
46 |
+
|
47 |
+
# I put a file local.txt in my desktop models folder to find out if it's running online
|
48 |
+
we_are_online = not pathlib.Path("models/local.txt").exists()
|
49 |
+
we_are_not_online = not we_are_online
|
50 |
+
|
51 |
+
golden_dataset = EmbeddingQAFinetuneDataset.from_json("data/golden_100.json")
|
52 |
+
|
53 |
+
# shutil.rmtree("models/models") # remove it - I wanted to clear the space on streamlit online
|
54 |
+
|
55 |
+
## PAGE CONFIGURATION
|
56 |
+
st.set_page_config(page_title="Ask Impact Theory",
|
57 |
+
page_icon="assets/impact-theory-logo-only.png",
|
58 |
+
layout="wide",
|
59 |
+
initial_sidebar_state="collapsed",
|
60 |
+
menu_items={'Report a bug': "https://www.extremelycoolapp.com/bug"})
|
61 |
+
|
62 |
+
|
63 |
+
image = "https://is2-ssl.mzstatic.com/image/thumb/Music122/v4/bd/34/82/bd348260-314c-5898-26c0-bef2e0388ebe/source/1200x1200bb.png"
|
64 |
+
|
65 |
+
|
66 |
+
def add_bg_from_local(image_file):
|
67 |
+
bin_str = get_base64_of_bin_file(image_file)
|
68 |
+
page_bg_img = f'''
|
69 |
+
<style>
|
70 |
+
.stApp {{
|
71 |
+
background-image: url("data:image/png;base64,{bin_str}");
|
72 |
+
background-size: 100% auto;
|
73 |
+
background-repeat: no-repeat;
|
74 |
+
background-attachment: fixed;
|
75 |
+
}}
|
76 |
+
</style>
|
77 |
+
'''
|
78 |
+
|
79 |
+
st.markdown(page_bg_img, unsafe_allow_html=True)
|
80 |
+
|
81 |
+
# COMMENT: I tried to create a dropdown menu but it's harder than it looks, so I gave up
|
82 |
+
# https://discuss.streamlit.io/t/streamlit-option-menu-is-a-simple-streamlit-component-that-allows-users-to-select-a-single-item-from-a-list-of-options-in-a-menu/20514
|
83 |
+
# not great, but it works
|
84 |
+
# selected = option_menu("About", ["Improvements","This"], #"Main Menu", ["Home", 'Settings'],
|
85 |
+
# icons=['house', 'gear'],
|
86 |
+
# menu_icon="cast",
|
87 |
+
# default_index=1)
|
88 |
+
|
89 |
+
# # Custom HTML/CSS for the banner
|
90 |
+
# base64_img = get_base64_of_bin_file("assets/it_tom_bilyeu.png")
|
91 |
+
# banner_menu_html = f"""
|
92 |
+
# <div class="banner">
|
93 |
+
# <img src= "data:image/png;base64,{base64_img}" alt="Banner Image">
|
94 |
+
# </div>
|
95 |
+
# <style>
|
96 |
+
# .banner {{
|
97 |
+
# width: 100%;
|
98 |
+
# height: auto;
|
99 |
+
# overflow: hidden;
|
100 |
+
# display: flex;
|
101 |
+
# justify-content: center;
|
102 |
+
# }}
|
103 |
+
# .banner img {{
|
104 |
+
# width: 130%;
|
105 |
+
# height: auto;
|
106 |
+
# object-fit: contain;
|
107 |
+
# }}
|
108 |
+
# </style>
|
109 |
+
# """
|
110 |
+
# st.components.v1.html(banner_menu_html)
|
111 |
+
|
112 |
+
|
113 |
+
# specify the primary menu definition
|
114 |
+
# it gives a vertical menu inside a navigation bar !!!
|
115 |
+
# menu_data = [
|
116 |
+
# {'icon': "far fa-copy", 'label':"Left End"},
|
117 |
+
# {'id':'Copy','icon':"🐙",'label':"Copy"},
|
118 |
+
# {'icon': "far fa-chart-bar", 'label':"Chart"},#no tooltip message
|
119 |
+
# {'icon': "far fa-address-book", 'label':"Book"},
|
120 |
+
# {'id':' Crazy return value 💀','icon': "💀", 'label':"Calendar"},
|
121 |
+
# {'icon': "far fa-clone", 'label':"Component"},
|
122 |
+
# {'icon': "fas fa-tachometer-alt", 'label':"Dashboard",'ttip':"I'm the Dashboard tooltip!"}, #can add a tooltip message
|
123 |
+
# {'icon': "far fa-copy", 'label':"Right End"},
|
124 |
+
# ]
|
125 |
+
# # we can override any part of the primary colors of the menu
|
126 |
+
# over_theme = {'txc_inactive': '#FFFFFF','menu_background':'red','txc_active':'yellow','option_active':'blue'}
|
127 |
+
# # over_theme = {'txc_inactive': '#FFFFFF'}
|
128 |
+
# menu_id = hc.nav_bar(menu_definition=menu_data,
|
129 |
+
# home_name='Home',
|
130 |
+
# override_theme=over_theme)
|
131 |
+
#get the id of the menu item clicked
|
132 |
+
# st.info(f"{menu_id=}")
|
133 |
+
## RERANKER
|
134 |
+
reranker = ReRanker('cross-encoder/ms-marco-MiniLM-L-6-v2')
|
135 |
+
## ENCODING --> tiktoken library
|
136 |
+
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
|
137 |
+
model_nameGPT = model_ids[1]
|
138 |
+
encoding = encoding_for_model(model_nameGPT)
|
139 |
+
# = get_encoding('gpt-3.5-turbo-0613')
|
140 |
+
##############
|
141 |
+
data_path = './data/impact_theory_data.json'
|
142 |
+
cache_path = 'data/impact_theory_cache.parquet'
|
143 |
+
data = load_data(data_path)
|
144 |
+
cache = None # load_content_cache(cache_path)
|
145 |
+
|
146 |
+
try:
|
147 |
+
# st.write("Loading secrets from secrets.toml")
|
148 |
+
Wapi_key = st.secrets['secrets']['WEAVIATE_API_KEY']
|
149 |
+
url = st.secrets['secrets']['WEAVIATE_ENDPOINT']
|
150 |
+
openai_api_key = st.secrets['secrets']['OPENAI_API_KEY']
|
151 |
+
|
152 |
+
hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
|
153 |
+
hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
|
154 |
+
# st.write("Secrets loaded from secrets.toml")
|
155 |
+
# st.write("HF_TOKEN", hf_token)
|
156 |
+
except:
|
157 |
+
st.write("Loading secrets from environment variables")
|
158 |
+
api_key = os.environ['WEAVIATE_API_KEY']
|
159 |
+
url = os.environ['WEAVIATE_ENDPOINT']
|
160 |
+
openai_api_key = os.environ['OPENAI_API_KEY']
|
161 |
+
|
162 |
+
hf_token = os.environ['LLAMA2_ENDPOINT_HF_TOKEN_chris']
|
163 |
+
hf_endpoint = os.environ['LLAMA2_ENDPOINT_UPLIMIT']
|
164 |
+
#%%
|
165 |
+
# model_default = 'sentence-transformers/all-mpnet-base-v2'
|
166 |
+
model_default = 'models/finetuned-all-mpnet-base-v2-300' if we_are_not_online \
|
167 |
+
else 'sentence-transformers/all-mpnet-base-v2'
|
168 |
+
|
169 |
+
available_models = ['sentence-transformers/all-mpnet-base-v2',
|
170 |
+
'sentence-transformers/all-MiniLM-L6-v2',
|
171 |
+
'models/finetuned-all-mpnet-base-v2-300']
|
172 |
+
|
173 |
+
#%%
|
174 |
+
models_urls = {'models/finetuned-all-mpnet-base-v2-300': "https://drive.google.com/drive/folders/1asJ37-AUv5nytLtH6hp6_bVV3_cZOXfj"}
|
175 |
+
|
176 |
+
def download_model_from_Gdrive(model_name_or_path, model_full_path):
|
177 |
+
print("Downloading model from Google Drive")
|
178 |
+
st.write("Downloading model from Google Drive")
|
179 |
+
assert model_name_or_path in models_urls, f"Model {model_name_or_path} not found in models_urls"
|
180 |
+
url = models_urls[model_name_or_path]
|
181 |
+
gdown.download_folder(url, output=model_full_path, quiet=False, use_cookies=False)
|
182 |
+
print("Model downloaded and saved to models folder")
|
183 |
+
# st.write("Model downloaded")
|
184 |
+
|
185 |
+
def download_model(model_name_or_path, model_full_path):
|
186 |
+
|
187 |
+
if model_name_or_path.startswith("models/"):
|
188 |
+
download_model_from_Gdrive(model_name_or_path, model_full_path)
|
189 |
+
print(f"Model {model_full_path} downloaded")
|
190 |
+
models_urls[model_name_or_path] = model_full_path
|
191 |
+
# st.sidebar.write(f"Model {model_full_path} downloaded")
|
192 |
+
|
193 |
+
elif model_name_or_path.startswith("sentence-transformers/"):
|
194 |
+
st.sidebar.write(f"Downloading Sentence Transformer model {model_name_or_path}")
|
195 |
+
model = SentenceTransformer(model_name_or_path) # HF looks into its own models folder/path
|
196 |
+
models_urls[model_name_or_path] = model_full_path
|
197 |
+
# st.sidebar.write(f"Model {model_name_or_path} downloaded")
|
198 |
+
model.save(model_full_path)
|
199 |
+
# st.sidebar.write(f"Model {model_name_or_path} saved to {model_full_path}")
|
200 |
+
|
201 |
+
# if 'modelspath' not in st.session_state:
|
202 |
+
# st.session_state['modelspath'] = None
|
203 |
+
# if st.session_state.modelspath is None:
|
204 |
+
# # let's create a temp folder on the first run
|
205 |
+
# persistent_dir = pathlib.Path("path/to/persistent_dir")
|
206 |
+
# persistent_dir.mkdir(parents=True, exist_ok=True)
|
207 |
+
# with tempfile.TemporaryDirectory() as temp_dir:
|
208 |
+
# st.session_state.modelspath = temp_dir
|
209 |
+
# print(f"Temporary directory created at {temp_dir}")
|
210 |
+
# # the temp folder disappears with the context, but not the one we've created manually
|
211 |
+
# else:
|
212 |
+
# temp_dir = st.session_state.modelspath
|
213 |
+
# print(f"Temporary directory already exists at {temp_dir}")
|
214 |
+
# # st.write(os.listdir(temp_dir))
|
215 |
+
|
216 |
+
#%%
|
217 |
+
# for streamlit online, we must download the model from google drive
|
218 |
+
# because github LFS doesn't work on forked repos
|
219 |
+
def check_model(model_name_or_path):
|
220 |
+
|
221 |
+
model_path = pathlib.Path(model_name_or_path)
|
222 |
+
model_full_path = str(pathlib.Path("models") / model_path) # this creates a models folder inside /models
|
223 |
+
model_full_path = model_full_path.replace("sentence-transformers/", "models/") # all are saved in models folder
|
224 |
+
|
225 |
+
if pathlib.Path(model_full_path).exists():
|
226 |
+
# let's use the model that's already there
|
227 |
+
print(f"Model {model_full_path} already exists")
|
228 |
+
|
229 |
+
|
230 |
+
# but delete everything else in we are online because
|
231 |
+
# streamlit online has limited space (and will shut down the app if it's full)
|
232 |
+
if we_are_online:
|
233 |
+
# st.sidebar.write(f"Model {model_full_path} already exists")
|
234 |
+
# st.sidebar.write(f"Deleting other models")
|
235 |
+
dirs = os.listdir("models/models")
|
236 |
+
# we get only the folder name, not the full path
|
237 |
+
dirs.remove(model_full_path.split('/')[-1])
|
238 |
+
for p in dirs:
|
239 |
+
dirpath = pathlib.Path("models/models") / p
|
240 |
+
if dirpath.is_dir():
|
241 |
+
shutil.rmtree(dirpath)
|
242 |
+
else:
|
243 |
+
|
244 |
+
if we_are_online:
|
245 |
+
# space issues on streamlit online, let's not leave anything behind
|
246 |
+
# and redownload the model eveery time
|
247 |
+
print("Deleting models/models folder")
|
248 |
+
if pathlib.Path('models/models').exists():
|
249 |
+
shutil.rmtree("models/models") # make room, if other models are there
|
250 |
+
# st.sidebar.write(f"models/models folder deleted")
|
251 |
+
|
252 |
+
download_model(model_name_or_path, model_full_path)
|
253 |
+
|
254 |
+
return model_full_path
|
255 |
+
|
256 |
+
#%% instantiate Weaviate client
|
257 |
+
def get_weaviate_client(api_key, url, model_name_or_path, openai_api_key):
|
258 |
+
client = WeaviateClient(api_key, url,
|
259 |
+
model_name_or_path=model_name_or_path,
|
260 |
+
openai_api_key=openai_api_key)
|
261 |
+
client.display_properties.append('summary')
|
262 |
+
available_classes = sorted(client.show_classes())
|
263 |
+
# st.write(f"Available classes: {available_classes}")
|
264 |
+
# st.write(f"Available classes type: {type(available_classes)}")
|
265 |
+
logger.info(available_classes)
|
266 |
+
return client, available_classes
|
267 |
+
|
268 |
+
|
269 |
+
##############
|
270 |
+
# data = load_data(data_path)
|
271 |
+
# guests list for sidebar
|
272 |
+
guest_list = sorted(list(set([d['guest'] for d in data])))
|
273 |
+
|
274 |
+
def main():
|
275 |
+
|
276 |
+
with st.sidebar:
|
277 |
+
# moved it to main area
|
278 |
+
# guest = st.selectbox('Select Guest',
|
279 |
+
# options=guest_list,
|
280 |
+
# index=None,
|
281 |
+
# placeholder='Select Guest')
|
282 |
+
_, center, _ = st.columns([3, 5, 3])
|
283 |
+
with center:
|
284 |
+
st.text("Search Lab")
|
285 |
+
|
286 |
+
_, center, _ = st.columns([2, 5, 3])
|
287 |
+
with center:
|
288 |
+
if we_are_online:
|
289 |
+
st.text("Running ONLINE")
|
290 |
+
st.text("(UNSTABLE)")
|
291 |
+
else:
|
292 |
+
st.text("Running OFFLINE")
|
293 |
+
st.write("----------")
|
294 |
+
|
295 |
+
alpha_input = st.slider(label='Alpha',min_value=0.00, max_value=1.00, value=0.40, step=0.05)
|
296 |
+
retrieval_limit = st.slider(label='Hybrid Search Results', min_value=10, max_value=300, value=10, step=10)
|
297 |
+
|
298 |
+
hybrid_filter = st.toggle('Filter Guest', True) # i.e. look only at guests' data
|
299 |
+
|
300 |
+
rerank = st.toggle('Use Reranker', True)
|
301 |
+
if rerank:
|
302 |
+
reranker_topk = st.slider(label='Reranker Top K',min_value=1, max_value=5, value=3, step=1)
|
303 |
+
else:
|
304 |
+
# needed to not fill the LLM with too many responses (> context size)
|
305 |
+
# we could make it dependent on the model
|
306 |
+
reranker_topk = 3
|
307 |
+
|
308 |
+
rag_it = st.toggle('RAG it', True)
|
309 |
+
if rag_it:
|
310 |
+
st.sidebar.write(f"Using LLM '{model_nameGPT}'")
|
311 |
+
llm_temperature = st.slider(label='LLM T˚', min_value=0.0, max_value=2.0, value=0.01, step=0.10 )
|
312 |
+
|
313 |
+
model_name_or_path = st.selectbox(label='Model Name:', options=available_models,
|
314 |
+
index=available_models.index(model_default),
|
315 |
+
placeholder='Select Model')
|
316 |
+
|
317 |
+
st.write("Experimental and time limited 2'")
|
318 |
+
finetune_model = st.toggle('Finetune on Modal A100 GPU', False)
|
319 |
+
if finetune_model:
|
320 |
+
from finetune_backend import finetune
|
321 |
+
if 'finetuned' in model_name_or_path:
|
322 |
+
st.write("Model already finetuned")
|
323 |
+
elif model_name_or_path.startswith("models/"):
|
324 |
+
st.write("Sentence Transformers models only!")
|
325 |
+
else:
|
326 |
+
try:
|
327 |
+
if 'finetuned' in finetune_model:
|
328 |
+
st.write("Model already finetuned")
|
329 |
+
else:
|
330 |
+
model_path = finetune(model_name_or_path, savemodel=True, outpath='models')
|
331 |
+
if model_path is not None:
|
332 |
+
if finetune_model.split('/')[-1] not in model_path:
|
333 |
+
st.write(model_path) # a warning from finetuning in this case
|
334 |
+
elif model_path not in available_models:
|
335 |
+
# finetuning generated a model, let's add it
|
336 |
+
available_models.append(model_path)
|
337 |
+
st.write("Model saved!")
|
338 |
+
except Exception:
|
339 |
+
st.write("Model not found on HF or error")
|
340 |
+
|
341 |
+
model_name_or_path = check_model(model_name_or_path)
|
342 |
+
client, available_classes = get_weaviate_client(Wapi_key, url, model_name_or_path, openai_api_key)
|
343 |
+
|
344 |
+
start_class = 'Impact_theory_all_mpnet_base_v2_finetuned'
|
345 |
+
|
346 |
+
class_name = st.selectbox(
|
347 |
+
label='Class Name:',
|
348 |
+
options=available_classes,
|
349 |
+
index=available_classes.index(start_class),
|
350 |
+
placeholder='Select Class Name'
|
351 |
+
)
|
352 |
+
|
353 |
+
st.write("----------")
|
354 |
+
|
355 |
+
c1,c2 = st.columns([8,1])
|
356 |
+
with c1:
|
357 |
+
show_metrics = st.toggle('Show Metrics on Golden set', False)
|
358 |
+
if show_metrics:
|
359 |
+
# _, center, _ = st.columns([3, 5, 3])
|
360 |
+
# with center:
|
361 |
+
# st.text("Metrics")
|
362 |
+
with c2:
|
363 |
+
with st.spinner(''):
|
364 |
+
metrics = execute_evaluation(golden_dataset, class_name, client, alpha=alpha_input)
|
365 |
+
if show_metrics:
|
366 |
+
kw_hit_rate = metrics['kw_hit_rate']
|
367 |
+
kw_mrr = metrics['kw_mrr']
|
368 |
+
hybrid_hit_rate = metrics['hybrid_hit_rate']
|
369 |
+
vector_hit_rate = metrics['vector_hit_rate']
|
370 |
+
vector_mrr = metrics['vector_mrr']
|
371 |
+
total_misses = metrics['total_misses']
|
372 |
+
|
373 |
+
st.text(f"KW hit rate: {kw_hit_rate}")
|
374 |
+
st.text(f"Vector hit rate: {vector_hit_rate}")
|
375 |
+
st.text(f"Hybrid hit rate: {hybrid_hit_rate}")
|
376 |
+
st.text(f"Hybrid MRR: {vector_mrr}")
|
377 |
+
st.text(f"Total misses: {total_misses}")
|
378 |
+
|
379 |
+
st.write("----------")
|
380 |
+
|
381 |
+
st.title("Chat with the Impact Theory podcasts!")
|
382 |
+
# st.image('./assets/impact-theory-logo.png', width=400)
|
383 |
+
st.image('assets/it_tom_bilyeu.png', use_column_width=True)
|
384 |
+
# st.subheader(f"Chat with the Impact Theory podcast: ")
|
385 |
+
st.write('\n')
|
386 |
+
# st.stop()
|
387 |
+
|
388 |
+
|
389 |
+
st.write("\u21D0 Open the sidebar to change Search settings \n ") # https://home.unicode.org also 21E0, 21B0 B2 D0
|
390 |
+
guest = st.selectbox('Select A Guest',
|
391 |
+
options=guest_list,
|
392 |
+
index=None,
|
393 |
+
placeholder='Select Guest')
|
394 |
+
|
395 |
+
|
396 |
+
col1, col2 = st.columns([7,3])
|
397 |
+
with col1:
|
398 |
+
if guest is None:
|
399 |
+
msg = f'Select a guest before asking your question:'
|
400 |
+
else:
|
401 |
+
msg = f'Enter your question about {guest}:'
|
402 |
+
|
403 |
+
textbox = st.empty()
|
404 |
+
# best solution I found to be able to change the text inside a text_input box afterwards, using a key
|
405 |
+
query = textbox.text_input(msg,
|
406 |
+
value="",
|
407 |
+
placeholder="You can refer to the guest with pronoun or drop the question mark",
|
408 |
+
key=st.session_state.key)
|
409 |
+
|
410 |
+
# st.write(f"Guest = {guest}")
|
411 |
+
# st.write(f"key = {st.session_state.key}")
|
412 |
+
|
413 |
+
st.write('\n\n\n\n\n')
|
414 |
+
|
415 |
+
reworded_query = {'changed': False, 'status': 'error'} # at start, the query is empty
|
416 |
+
valid_response = [] # at start, the query is empty, so prevent the search
|
417 |
+
|
418 |
+
if query:
|
419 |
+
|
420 |
+
if guest is None:
|
421 |
+
st.session_state.key += 1
|
422 |
+
query = textbox.text_input(msg,
|
423 |
+
value="",
|
424 |
+
placeholder="YOU MUST SELECT A GUEST BEFORE ASKING A QUESTION",
|
425 |
+
key=st.session_state.key)
|
426 |
+
# st.write(f"key = {st.session_state.key}")
|
427 |
+
st.stop()
|
428 |
+
else:
|
429 |
+
# st.write(f'It looks like you selected {guest} as a filter (It is ignored for now).')
|
430 |
+
|
431 |
+
with col2:
|
432 |
+
# let's add a nice pulse bar while generating the response
|
433 |
+
with hc.HyLoader('', hc.Loaders.pulse_bars, primary_color= 'red', height=50): #"#0e404d" for image green
|
434 |
+
# with st.spinner('Generating Response...'):
|
435 |
+
|
436 |
+
with col1:
|
437 |
+
|
438 |
+
# let's use Llama2 here
|
439 |
+
reworded_query = reword_query(query, guest,
|
440 |
+
model_name='llama2-13b-chat')
|
441 |
+
query = reworded_query['rewritten_question']
|
442 |
+
|
443 |
+
# we can arrive here only if a guest was selected
|
444 |
+
where_filter = WhereFilter(path=['guest'], operator='Equal', valueText=guest).todict() \
|
445 |
+
if hybrid_filter else None
|
446 |
+
|
447 |
+
hybrid_response = client.hybrid_search(query,
|
448 |
+
class_name,
|
449 |
+
# properties=['content'], #['title', 'summary', 'content'],
|
450 |
+
alpha=alpha_input,
|
451 |
+
display_properties=client.display_properties,
|
452 |
+
where_filter=where_filter,
|
453 |
+
limit=retrieval_limit)
|
454 |
+
response = hybrid_response
|
455 |
+
|
456 |
+
if rerank:
|
457 |
+
# rerank results with cross encoder
|
458 |
+
ranked_response = reranker.rerank(response, query,
|
459 |
+
apply_sigmoid=True, # score between 0 and 1
|
460 |
+
top_k=reranker_topk)
|
461 |
+
logger.info(ranked_response)
|
462 |
+
expanded_response = expand_content(ranked_response, cache,
|
463 |
+
content_key='doc_id',
|
464 |
+
create_new_list=True)
|
465 |
+
|
466 |
+
response = expanded_response
|
467 |
+
|
468 |
+
# make sure token count < threshold
|
469 |
+
token_threshold = 8000 if model_nameGPT == model_ids[0] else 3500
|
470 |
+
valid_response = validate_token_threshold(response,
|
471 |
+
question_answering_prompt_series,
|
472 |
+
query=query,
|
473 |
+
tokenizer= encoding,# variable from ENCODING,
|
474 |
+
token_threshold=token_threshold,
|
475 |
+
verbose=True)
|
476 |
+
# st.write(f"Number of results: {len(valid_response)}")
|
477 |
+
|
478 |
+
|
479 |
+
# I jump out of col1 to get all page width, so need to retest query
|
480 |
+
if query is not None and reworded_query['status'] != 'error':
|
481 |
+
show_query = st.toggle('Show rewritten query', False)
|
482 |
+
if show_query: # or reworded_query['changed']:
|
483 |
+
st.write(f"Rewritten query: {query}")
|
484 |
+
|
485 |
+
# creates container for LLM response to position it above search results
|
486 |
+
chat_container, response_box = [], st.empty()
|
487 |
+
# # RAG time !! execute chat call to LLM
|
488 |
+
if rag_it:
|
489 |
+
# st.subheader("Response from Impact Theory (context)")
|
490 |
+
# will appear under the answer, moved it into the response box
|
491 |
+
|
492 |
+
# generate LLM prompt
|
493 |
+
prompt = generate_prompt_series(query=query, results=valid_response)
|
494 |
+
|
495 |
+
|
496 |
+
GPTllm = GPT_Turbo(model=model_nameGPT,
|
497 |
+
api_key=st.secrets['secrets']['OPENAI_API_KEY'])
|
498 |
+
try:
|
499 |
+
# inserts chat stream from LLM
|
500 |
+
for resp in GPTllm.get_chat_completion(prompt=prompt,
|
501 |
+
temperature=llm_temperature,
|
502 |
+
max_tokens=350,
|
503 |
+
show_response=True,
|
504 |
+
stream=True):
|
505 |
+
|
506 |
+
with response_box:
|
507 |
+
content = resp.choices[0].delta.content
|
508 |
+
if content:
|
509 |
+
chat_container.append(content)
|
510 |
+
result = "".join(chat_container).strip()
|
511 |
+
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
|
512 |
+
except BadRequestError as e:
|
513 |
+
logger.info('Making request with smaller context')
|
514 |
+
|
515 |
+
valid_response = validate_token_threshold(response,
|
516 |
+
question_answering_prompt_series,
|
517 |
+
query=query,
|
518 |
+
tokenizer=encoding,
|
519 |
+
token_threshold=3500,
|
520 |
+
verbose=True)
|
521 |
+
# if reranker is off, we may receive a LOT of responses
|
522 |
+
# so we must reduce the context size manually
|
523 |
+
if not rerank:
|
524 |
+
valid_response = valid_response[:reranker_topk]
|
525 |
+
|
526 |
+
prompt = generate_prompt_series(query=query, results=valid_response)
|
527 |
+
for resp in GPTllm.get_chat_completion(prompt=prompt,
|
528 |
+
temperature=llm_temperature,
|
529 |
+
max_tokens=350, # expand for more verbose answers
|
530 |
+
show_response=True,
|
531 |
+
stream=True):
|
532 |
+
try:
|
533 |
+
# inserts chat stream from LLM
|
534 |
+
with response_box:
|
535 |
+
content = resp.choice[0].delta.content
|
536 |
+
if content:
|
537 |
+
chat_container.append(content)
|
538 |
+
result = "".join(chat_container).strip()
|
539 |
+
response_box.markdown(f"### Response from Impact Theory (RAG):\n\n{result}")
|
540 |
+
except Exception as e:
|
541 |
+
print(e)
|
542 |
+
|
543 |
+
st.markdown("----")
|
544 |
+
st.subheader("Search Results")
|
545 |
+
|
546 |
+
for i, hit in enumerate(valid_response):
|
547 |
+
col1, col2 = st.columns([7, 3], gap='large')
|
548 |
+
image = hit['thumbnail_url'] # get thumbnail_url
|
549 |
+
episode_url = hit['episode_url'] # get episode_url
|
550 |
+
title = hit["title"] # get title
|
551 |
+
show_length = hit["length"] # get length
|
552 |
+
time_string = str(timedelta(seconds=show_length)) # convert show_length to readable time string
|
553 |
+
|
554 |
+
with col1:
|
555 |
+
st.write(search_result(i=i,
|
556 |
+
url=episode_url,
|
557 |
+
guest=hit['guest'],
|
558 |
+
title=title,
|
559 |
+
content='',
|
560 |
+
length=time_string),
|
561 |
+
unsafe_allow_html=True)
|
562 |
+
st.write('\n\n')
|
563 |
+
|
564 |
+
with col2:
|
565 |
+
#st.write(f"<a href={episode_url} <img src={image} width='200'></a>",
|
566 |
+
# unsafe_allow_html=True)
|
567 |
+
#st.markdown(f"[![{title}]({image})]({episode_url})")
|
568 |
+
# st.markdown(f'<a href="{episode_url}">'
|
569 |
+
# f'<img src={image} '
|
570 |
+
# f'caption={title.split("|")[0]} width=200, use_column_width=False />'
|
571 |
+
# f'</a>',
|
572 |
+
# unsafe_allow_html=True)
|
573 |
+
|
574 |
+
st.image(image, caption=title.split('|')[0], width=200, use_column_width=False)
|
575 |
+
# let's use all width for the content
|
576 |
+
st.write(hit['content'])
|
577 |
+
|
578 |
+
|
579 |
+
def get_answer(query, valid_response, GPTllm):
|
580 |
+
|
581 |
+
# generate LLM prompt
|
582 |
+
prompt = generate_prompt_series(query=query,
|
583 |
+
results=valid_response)
|
584 |
+
|
585 |
+
return GPTllm.get_chat_completion(prompt=prompt,
|
586 |
+
system_message='answer this question based on the podcast material',
|
587 |
+
temperature=0,
|
588 |
+
max_tokens=500,
|
589 |
+
stream=False,
|
590 |
+
show_response=False)
|
591 |
+
|
592 |
+
def reword_query(query, guest, model_name='llama2-13b-chat', response_processing=True):
|
593 |
+
""" Asks LLM to rewrite the query when the guest name is missing.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
query (str): user query
|
597 |
+
guest (str): guest name
|
598 |
+
model_name (str, optional): name of a LLM model to be used
|
599 |
+
"""
|
600 |
+
|
601 |
+
# tags = {'llama2-13b-chat': {'start': '<s>', 'end': '</s>', 'instruction': '[INST]', 'system': '[SYS]'},
|
602 |
+
# 'gpt-3.5-turbo-0613': {'start': '<|startoftext|>', 'end': '', 'instruction': "```", 'system': ```}}
|
603 |
+
|
604 |
+
prompt_fields = {
|
605 |
+
"you_are":f"You are an expert in linguistics and semantics, analyzing the question asked by a user to a vector search system, \
|
606 |
+
and making sure that the question is well formulated and that the system can understand it.",
|
607 |
+
|
608 |
+
"your_task":f"Your task is to detect if the name of the guest ({guest}) is mentioned in the user's question, \
|
609 |
+
and if that is not the case, rewrite the question using the guest name, \
|
610 |
+
without changing the meaning of the question. \
|
611 |
+
Most of the time, the user will have used a pronoun to designate the guest, in which case, \
|
612 |
+
simply replace the pronoun with the guest name.",
|
613 |
+
|
614 |
+
"question":f"If the user mentions the guest name, ie {query}, just return his question as is. \
|
615 |
+
If the user does not mention the guest name, rewrite the question using the guest name.",
|
616 |
+
|
617 |
+
"final_instruction":f"Only regerate the requested rewritten question or the original, WITHOUT ANY COMMENT OR REPHRASING. \
|
618 |
+
Your answer must be as close as possible to the original question, \
|
619 |
+
and exactly identical, word for word, if the user mentions the guest name, i.e. {guest}.",
|
620 |
+
}
|
621 |
+
|
622 |
+
# prompt created by chatGPT :-)
|
623 |
+
# and Llama still outputs the original question and precedes the answer with 'rewritten question'
|
624 |
+
prompt_fields2 = {
|
625 |
+
"you_are": (
|
626 |
+
"You are an expert in linguistics and semantics. Your role is to analyze questions asked to a vector search system."
|
627 |
+
),
|
628 |
+
"your_task": (
|
629 |
+
f"Detect if the guest's FULL name, {guest}, is mentioned in the user's question. "
|
630 |
+
"If not, rewrite the question by replacing pronouns or indirect references with the guest's name." \
|
631 |
+
"If yes, return the original question as is, without any change at all, not even punctuation,"
|
632 |
+
"except a question mark that you MUST add if it's missing."
|
633 |
+
),
|
634 |
+
"question": (
|
635 |
+
f"Original question: '{query}'. "
|
636 |
+
"Rewrite this question to include the guest's FULL name if it's not already mentioned."
|
637 |
+
"The Only thing you can and MUST add is a question mark if it's missing."
|
638 |
+
),
|
639 |
+
"final_instruction": (
|
640 |
+
"Create a rewritten question or keep the original question as is. "
|
641 |
+
"Do not include any labels, titles, or additional text before or after the question."
|
642 |
+
"The Only thing you can and MUST add is a question mark if it's missing."
|
643 |
+
"Return a json object, with the key 'original_question' for the original question, \
|
644 |
+
and 'rewritten_question' for the rewritten question \
|
645 |
+
and 'changed' being True if you changed the answer, otherwise False."
|
646 |
+
),
|
647 |
+
}
|
648 |
+
|
649 |
+
|
650 |
+
if model_name == 'llama2-13b-chat':
|
651 |
+
# special tags are used:
|
652 |
+
# `<s>` - start prompt tag
|
653 |
+
# `[INST], [/INST]` - Opening and closing model instruction tags
|
654 |
+
# `<<<SYS>>>, <</SYS>>` - Opening and closing system prompt tags
|
655 |
+
llama_prompt = """
|
656 |
+
<s>[INST] <<SYS>>
|
657 |
+
{you_are}
|
658 |
+
<</SYS>>
|
659 |
+
{your_task}\n
|
660 |
+
|
661 |
+
```
|
662 |
+
\n\n
|
663 |
+
Question: {question}\n
|
664 |
+
{final_instruction} [/INST]
|
665 |
+
|
666 |
+
Answer:
|
667 |
+
"""
|
668 |
+
prompt = llama_prompt.format(**prompt_fields2)
|
669 |
+
|
670 |
+
hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN_chris']
|
671 |
+
# hf_token = st.secrets['secrets']['LLAMA2_ENDPOINT_HF_TOKEN']
|
672 |
+
|
673 |
+
hf_endpoint = st.secrets['secrets']['LLAMA2_ENDPOINT_UPLIMIT']
|
674 |
+
|
675 |
+
headers = {"Authorization": f"Bearer {hf_token}",
|
676 |
+
"Content-Type": "application/json",}
|
677 |
+
|
678 |
+
json_body = {
|
679 |
+
"inputs": prompt,
|
680 |
+
"parameters": {"max_new_tokens":400,
|
681 |
+
"repetition_penalty": 1.0,
|
682 |
+
"temperature":0.01}
|
683 |
+
}
|
684 |
+
|
685 |
+
response = requests.request("POST", hf_endpoint, headers=headers, data=json.dumps(json_body))
|
686 |
+
response = json.loads(response.content.decode("utf-8"))
|
687 |
+
# ^ will not process the badly formatted generated text, so we do it ourselves
|
688 |
+
|
689 |
+
if isinstance(response, dict) and 'error' in response:
|
690 |
+
print("Found error")
|
691 |
+
print(response)
|
692 |
+
# return {'error': response['error'], 'rewritten_question': query, 'changed': False, 'status': 'error'}
|
693 |
+
# I test this here otherwise it gets in col 2 or 1, which are too
|
694 |
+
# if reworded_query['status'] == 'error':
|
695 |
+
# st.write(f"Error in LLM response: 'error':{reworded_query['error']}")
|
696 |
+
# st.write("The LLM could not connect to the server. Please try again later.")
|
697 |
+
# st.stop()
|
698 |
+
return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
|
699 |
+
|
700 |
+
if response_processing:
|
701 |
+
if isinstance(response, list) and isinstance(response[0], dict) and 'generated_text' in response[0]:
|
702 |
+
print("Found generated text")
|
703 |
+
response0 = response[0]['generated_text']
|
704 |
+
pattern = r'\"(\w+)\":\s*(\".*?\"|\w+)'
|
705 |
+
|
706 |
+
matches = re.findall(pattern, response0)
|
707 |
+
# let's build a dictionary
|
708 |
+
result = {key: json.loads(value) if value.startswith("\"") else value for key, value in matches}
|
709 |
+
return result | {'status': 'success'}
|
710 |
+
else:
|
711 |
+
print("Found no answer")
|
712 |
+
return reword_query(query, guest, model_name='gpt-3.5-turbo-0613')
|
713 |
+
# return {'original_question': query, 'rewritten_question': query, 'changed': False, 'status': 'no properly formatted answer' }
|
714 |
+
else:
|
715 |
+
return response
|
716 |
+
# return response
|
717 |
+
# assert 'error' not in response, f"Error in LLM response: {response['error']}"
|
718 |
+
# assert 'generated_text' in response[0], f"Error in LLM response: {response}, no 'generated_text' field"
|
719 |
+
# # let's extract the rewritten question
|
720 |
+
# return response[0]['generated_text'] .split("Rewritten question: '")[-1][:-1]
|
721 |
+
|
722 |
+
else:
|
723 |
+
# assume openai
|
724 |
+
model_ids = ['gpt-3.5-turbo-16k', 'gpt-3.5-turbo-0613']
|
725 |
+
model_name = model_ids[1]
|
726 |
+
GPTllm = GPT_Turbo(model=model_name,
|
727 |
+
api_key=st.secrets['secrets']['OPENAI_API_KEY'])
|
728 |
+
|
729 |
+
openai_prompt = """
|
730 |
+
{your_task}\n
|
731 |
+
```
|
732 |
+
\n\n
|
733 |
+
Question: {question}\n
|
734 |
+
{final_instruction}
|
735 |
+
|
736 |
+
Answer:
|
737 |
+
"""
|
738 |
+
prompt = openai_prompt.format(**prompt_fields)
|
739 |
+
|
740 |
+
try:
|
741 |
+
resp = GPTllm.get_chat_completion(prompt=openai_prompt,
|
742 |
+
system_message=prompt_fields['you_are'],
|
743 |
+
temperature=0.01,
|
744 |
+
max_tokens=1500, # it's a question...
|
745 |
+
show_response=True,
|
746 |
+
stream=False)
|
747 |
+
return {'rewritten_question': resp.choices[0].delta.content,
|
748 |
+
'changed': True, 'status': 'success'}
|
749 |
+
except Exception:
|
750 |
+
return {'rewritten_question': query, 'changed': False, 'status': 'not success'}
|
751 |
+
|
752 |
+
|
753 |
+
if __name__ == '__main__':
|
754 |
+
main()
|
755 |
+
# %%
|
backend.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import modal
|
2 |
+
|
3 |
+
from typing import List, Dict, Tuple, Union, Callable
|
4 |
+
# from preprocessing import FileIO
|
5 |
+
|
6 |
+
# assets = modal.Mount.from_local_dir(
|
7 |
+
# "./data",
|
8 |
+
# # condition=lambda pth: not ".venv" in pth,
|
9 |
+
# remote_path="./data",
|
10 |
+
# )
|
11 |
+
|
12 |
+
|
13 |
+
stub = modal.Stub("vector-search-project")
|
14 |
+
vector_search = modal.Image.debian_slim().pip_install(
|
15 |
+
"sentence_transformers==2.2.2", "llama_index==0.9.6.post1", "angle_emb==0.1.5"
|
16 |
+
)
|
17 |
+
|
18 |
+
stub.volume = modal.Volume.new()
|
19 |
+
|
20 |
+
|
21 |
+
@stub.function(image=vector_search,
|
22 |
+
gpu="A100",
|
23 |
+
timeout=600,
|
24 |
+
volumes={"/root/models": stub.volume}
|
25 |
+
# secrets are available in the environment with os.environ["SECRET_NAME"]
|
26 |
+
# secret=modal.Secret.from_name("my-huggingface-secret")
|
27 |
+
)
|
28 |
+
def encode_content_splits(content_splits,
|
29 |
+
model=None, # path or name of model
|
30 |
+
**kwargs
|
31 |
+
):
|
32 |
+
""" kwargs provided in case encode method has extra arguments """
|
33 |
+
from sentence_transformers import SentenceTransformer
|
34 |
+
|
35 |
+
import os, time
|
36 |
+
models_list = os.listdir('/root/models')
|
37 |
+
print("Models:", models_list)
|
38 |
+
|
39 |
+
if isinstance(model, str) and model[-1] == "/":
|
40 |
+
model = model[:-1]
|
41 |
+
|
42 |
+
if isinstance(model, str):
|
43 |
+
model = model.split('/')[-1]
|
44 |
+
|
45 |
+
if isinstance(model, str) and model in models_list:
|
46 |
+
|
47 |
+
if "UAE-Large-V1-300" in model:
|
48 |
+
print("Loading finetuned UAE-Large-V1-300 model from Modal Volume")
|
49 |
+
|
50 |
+
from angle_emb import AnglE
|
51 |
+
model = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1',
|
52 |
+
pretrained_model_path=os.path.join('/root/models', model),
|
53 |
+
pooling_strategy='cls').cuda()
|
54 |
+
kwargs['to_numpy'] = True
|
55 |
+
|
56 |
+
# this model doesn't accept list of lists
|
57 |
+
if isinstance(content_splits[0], list):
|
58 |
+
content_splits = [chunk for episode in content_splits for chunk in episode]
|
59 |
+
|
60 |
+
else:
|
61 |
+
print(f"Loading model {model} from Modal volume")
|
62 |
+
model = SentenceTransformer(os.path.join('/root/models', model))
|
63 |
+
|
64 |
+
elif isinstance(model, str):
|
65 |
+
if model in models_list:
|
66 |
+
print(f"Loading model {model} from Modal volume")
|
67 |
+
model = SentenceTransformer(os.path.join('/root/models', model))
|
68 |
+
else:
|
69 |
+
print(f"Model {model} not found in Modal volume, loading from HuggingFace")
|
70 |
+
model = SentenceTransformer(model)
|
71 |
+
|
72 |
+
else:
|
73 |
+
print(f"Using model provided as argument")
|
74 |
+
if 'save' in kwargs:
|
75 |
+
if isinstance(kwargs['save'], str) and kwargs['save'][-1] == '/':
|
76 |
+
kwargs['save'] = kwargs['save'][:-1]
|
77 |
+
kwargs['save'] = kwargs['save'].split('/')[-1]
|
78 |
+
fname = os.path.join('/root/models', kwargs['save'])
|
79 |
+
print(f"Saving model in {fname}")
|
80 |
+
# model.save(fname)
|
81 |
+
print(f"Model saved in {fname}")
|
82 |
+
kwargs.pop('save')
|
83 |
+
|
84 |
+
print("Starting encoding")
|
85 |
+
start = time.perf_counter()
|
86 |
+
|
87 |
+
emb = [list(zip(episode, model.encode(episode, **kwargs))) for episode in content_splits]
|
88 |
+
end = time.perf_counter() - start
|
89 |
+
print(f"GPU processing lasted {end:.2f} seconds")
|
90 |
+
print("Encoding finished")
|
91 |
+
|
92 |
+
return emb
|
93 |
+
|
94 |
+
|
95 |
+
@stub.function(image=vector_search, gpu="A100", timeout=120,
|
96 |
+
mounts=[modal.Mount.from_local_dir("./data",
|
97 |
+
remote_path="/root/data",
|
98 |
+
condition=lambda pth: ".json" in pth)],
|
99 |
+
volumes={"/root/models": stub.volume}
|
100 |
+
)
|
101 |
+
def finetune(training_path='./data/training_data_300.json',
|
102 |
+
valid_path='./data/validation_data_100.json',
|
103 |
+
model_id=None):
|
104 |
+
|
105 |
+
import os
|
106 |
+
print("Data:", os.listdir('/root/data'))
|
107 |
+
print("Models:", os.listdir('/root/models'))
|
108 |
+
|
109 |
+
if model_id is None:
|
110 |
+
print("No model ID provided")
|
111 |
+
return None
|
112 |
+
elif isinstance(model_id, str) and model_id[-1] == "/":
|
113 |
+
model_id = model_id[:-1]
|
114 |
+
|
115 |
+
|
116 |
+
from llama_index.finetuning import EmbeddingQAFinetuneDataset
|
117 |
+
|
118 |
+
training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
|
119 |
+
valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)
|
120 |
+
print("Datasets loaded")
|
121 |
+
|
122 |
+
num_training_examples = len(training_set.queries)
|
123 |
+
print(f"Training examples: {num_training_examples}")
|
124 |
+
|
125 |
+
from llama_index.finetuning import SentenceTransformersFinetuneEngine
|
126 |
+
|
127 |
+
print(f"Model Name is {model_id}")
|
128 |
+
model_ext = model_id.split('/')[1]
|
129 |
+
|
130 |
+
ft_model_name = f'finetuned-{model_ext}-{num_training_examples}'
|
131 |
+
model_outpath = os.path.join("/root/models", ft_model_name)
|
132 |
+
|
133 |
+
print(f'Model ID: {model_id}')
|
134 |
+
print(f'Model Outpath: {model_outpath}')
|
135 |
+
|
136 |
+
finetune_engine = SentenceTransformersFinetuneEngine(
|
137 |
+
training_set,
|
138 |
+
batch_size=32,
|
139 |
+
model_id=model_id,
|
140 |
+
model_output_path=model_outpath,
|
141 |
+
val_dataset=valid_set,
|
142 |
+
epochs=10
|
143 |
+
)
|
144 |
+
import io, os, zipfile, glob, time
|
145 |
+
try:
|
146 |
+
start = time.perf_counter()
|
147 |
+
finetune_engine.finetune()
|
148 |
+
end = time.perf_counter() - start
|
149 |
+
print(f"GPU processing lasted {end:.2f} seconds")
|
150 |
+
|
151 |
+
print(os.listdir('/root/models'))
|
152 |
+
stub.volume.commit() # Persist changes, ie the finetumed model
|
153 |
+
|
154 |
+
# TODO SHARE THE MODEL ON HUGGINGFACE
|
155 |
+
# https://huggingface.co/docs/transformers/v4.15.0/model_sharing
|
156 |
+
|
157 |
+
folder_to_zip = model_outpath
|
158 |
+
# Zip the contents of the folder at 'folder_path' and return a BytesIO object.
|
159 |
+
bytes_buffer = io.BytesIO()
|
160 |
+
|
161 |
+
with zipfile.ZipFile(bytes_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
|
162 |
+
for file_path in glob.glob(folder_to_zip + "/**", recursive=True):
|
163 |
+
print(f"Processed file {file_path}")
|
164 |
+
zip_file.write(file_path, os.path.relpath(file_path, start=folder_to_zip))
|
165 |
+
|
166 |
+
# Move the pointer to the start of the BytesIO buffer before returning
|
167 |
+
bytes_buffer.seek(0)
|
168 |
+
# You can now return this zipped_folder object, write it to a file, send it over a network, etc.
|
169 |
+
# Replace with the path to the folder you want to zip
|
170 |
+
zippedio = bytes_buffer
|
171 |
+
|
172 |
+
return zippedio
|
173 |
+
except:
|
174 |
+
return "Finetuning failed"
|
175 |
+
|
176 |
+
|
177 |
+
@stub.local_entrypoint()
|
178 |
+
def test_method(content_splits=[["a"]]):
|
179 |
+
output = encode_content_splits.remote(content_splits)
|
180 |
+
return output
|
181 |
+
|
182 |
+
# deploy it with
|
183 |
+
# modal token set --token-id ak-xxxxxx --token-secret as-xxxxx # given when we create a new token
|
184 |
+
# modal deploy podcast/1/backend.py
|
185 |
+
# View Deployment: https://modal.com/apps/jpbianchi/falcon_hackaton-project <<< use this project name
|
class_templates.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
impact_theory_class_properties = [
|
2 |
+
{'name': 'title',
|
3 |
+
'dataType': ['text'],
|
4 |
+
'indexFilterable': True,
|
5 |
+
'indexSearchable': True},
|
6 |
+
{'name': 'video_id',
|
7 |
+
'dataType': ['text'],
|
8 |
+
'indexFilterable': True,
|
9 |
+
'indexSearchable': False},
|
10 |
+
{'name': 'length',
|
11 |
+
'dataType': ['int'],
|
12 |
+
'indexFilterable': True,
|
13 |
+
'indexSearchable': False},
|
14 |
+
{'name': 'thumbnail_url',
|
15 |
+
'dataType': ['text'],
|
16 |
+
'indexFilterable': False,
|
17 |
+
'indexSearchable': False},
|
18 |
+
{'name': 'views',
|
19 |
+
'dataType': ['int'],
|
20 |
+
'indexFilterable': True,
|
21 |
+
'indexSearchable': False},
|
22 |
+
{'name': 'episode_url',
|
23 |
+
'dataType': ['text'],
|
24 |
+
'indexFilterable': False,
|
25 |
+
'indexSearchable': False},
|
26 |
+
{'name': 'doc_id',
|
27 |
+
'dataType': ['text'],
|
28 |
+
'indexFilterable': True,
|
29 |
+
'indexSearchable': False},
|
30 |
+
{'name': 'guest',
|
31 |
+
'dataType': ['text'],
|
32 |
+
'indexFilterable': True,
|
33 |
+
'indexSearchable': True},
|
34 |
+
{'name': 'summary',
|
35 |
+
'dataType': ['text'],
|
36 |
+
'indexFilterable': False,
|
37 |
+
'indexSearchable': True},
|
38 |
+
{'name': 'content',
|
39 |
+
'dataType': ['text'],
|
40 |
+
'indexFilterable': False,
|
41 |
+
'indexSearchable': True},
|
42 |
+
]
|
43 |
+
# {'name': 'publish_date',
|
44 |
+
# 'dataType': ['date'],
|
45 |
+
# 'indexFilterable': True,
|
46 |
+
# 'indexSearchable': False},
|
finetune_backend.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import os, time, io, zipfile
|
3 |
+
from preprocessing import FileIO
|
4 |
+
import shutil
|
5 |
+
import modal
|
6 |
+
from llama_index.finetuning import EmbeddingQAFinetuneDataset
|
7 |
+
|
8 |
+
from dotenv import load_dotenv, find_dotenv
|
9 |
+
env = load_dotenv(find_dotenv('env'), override=True)
|
10 |
+
|
11 |
+
#%%
|
12 |
+
training_path = 'data/training_data_300.json'
|
13 |
+
valid_path = 'data/validation_data_100.json'
|
14 |
+
|
15 |
+
training_set = EmbeddingQAFinetuneDataset.from_json(training_path)
|
16 |
+
valid_set = EmbeddingQAFinetuneDataset.from_json(valid_path)
|
17 |
+
|
18 |
+
def finetune(model='all-mpnet-base-v2', savemodel=False, outpath='.'):
|
19 |
+
""" Finetunes a model on Modal GPU A100.
|
20 |
+
The model is saved in /root/models on a Modal volume
|
21 |
+
and can be stored locally.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
model (str): the Sentence Transformer model name
|
25 |
+
savemodel (bool, optional): whether to save the model or not.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
path of the saved model (when saved)
|
29 |
+
"""
|
30 |
+
f = modal.Function.lookup("vector-search-project", "finetune")
|
31 |
+
model = model.replace('/','')
|
32 |
+
|
33 |
+
if 'sentence-transformers' not in model:
|
34 |
+
model = f"sentence-transformers/{model}"
|
35 |
+
|
36 |
+
fullpath = os.path.join(outpath, f"finetuned-{model}-300")
|
37 |
+
|
38 |
+
if os.path.exists(fullpath):
|
39 |
+
msg = "Model already exists!"
|
40 |
+
print(msg)
|
41 |
+
return msg
|
42 |
+
|
43 |
+
start = time.perf_counter()
|
44 |
+
finetuned_model = f.remote(training_path, valid_path, model_id=model)
|
45 |
+
|
46 |
+
end = time.perf_counter() - start
|
47 |
+
print(f"Finetuning with GPU lasted {end:.2f} seconds")
|
48 |
+
|
49 |
+
if savemodel:
|
50 |
+
|
51 |
+
with open(fullpath, 'wb') as file:
|
52 |
+
# Write the contents of the BytesIO object to a new file
|
53 |
+
file.write(finetuned_model.getbuffer())
|
54 |
+
print(f"Model saved in {fullpath}")
|
55 |
+
return fullpath
|
helpers.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Dict, Any
|
2 |
+
import time
|
3 |
+
from tqdm.notebook import tqdm
|
4 |
+
from rich import print
|
5 |
+
|
6 |
+
from retrieval_evaluation import calc_hit_rate_scores, calc_mrr_scores, record_results, add_params
|
7 |
+
from llama_index.finetuning import EmbeddingQAFinetuneDataset
|
8 |
+
from weaviate_interface import WeaviateClient
|
9 |
+
|
10 |
+
|
11 |
+
def retrieval_evaluation(dataset: EmbeddingQAFinetuneDataset,
|
12 |
+
class_name: str,
|
13 |
+
retriever: WeaviateClient,
|
14 |
+
retrieve_limit: int=5,
|
15 |
+
chunk_size: int=256,
|
16 |
+
hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
|
17 |
+
display_properties: List[str]=['doc_id', 'guest', 'content'],
|
18 |
+
dir_outpath: str='./eval_results',
|
19 |
+
include_miss_info: bool=False,
|
20 |
+
user_def_params: Dict[str,Any]=None
|
21 |
+
) -> Dict[str, str|int|float]:
|
22 |
+
'''
|
23 |
+
Given a dataset and a retriever evaluate the performance of the retriever. Returns a dict of kw and vector
|
24 |
+
hit rates and mrr scores. If inlude_miss_info is True, will also return a list of kw and vector responses
|
25 |
+
and their associated queries that did not return a hit, for deeper analysis. Text file with results output
|
26 |
+
is automatically saved in the dir_outpath directory.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
-----
|
30 |
+
dataset: EmbeddingQAFinetuneDataset
|
31 |
+
Dataset to be used for evaluation
|
32 |
+
class_name: str
|
33 |
+
Name of Class on Weaviate host to be used for retrieval
|
34 |
+
retriever: WeaviateClient
|
35 |
+
WeaviateClient object to be used for retrieval
|
36 |
+
retrieve_limit: int=5
|
37 |
+
Number of documents to retrieve from Weaviate host
|
38 |
+
chunk_size: int=256
|
39 |
+
Number of tokens used to chunk text. This value is purely for results
|
40 |
+
recording purposes and does not affect results.
|
41 |
+
display_properties: List[str]=['doc_id', 'content']
|
42 |
+
List of properties to be returned from Weaviate host for display in response
|
43 |
+
dir_outpath: str='./eval_results'
|
44 |
+
Directory path for saving results. Directory will be created if it does not
|
45 |
+
already exist.
|
46 |
+
include_miss_info: bool=False
|
47 |
+
Option to include queries and their associated kw and vector response values
|
48 |
+
for queries that are "total misses"
|
49 |
+
user_def_params : dict=None
|
50 |
+
Option for user to pass in a dictionary of user-defined parameters and their values.
|
51 |
+
'''
|
52 |
+
|
53 |
+
results_dict = {'n':retrieve_limit,
|
54 |
+
'Retriever': retriever.model_name_or_path,
|
55 |
+
'chunk_size': chunk_size,
|
56 |
+
'kw_hit_rate': 0,
|
57 |
+
'kw_mrr': 0,
|
58 |
+
'vector_hit_rate': 0,
|
59 |
+
'vector_mrr': 0,
|
60 |
+
'total_misses': 0,
|
61 |
+
'total_questions':0
|
62 |
+
}
|
63 |
+
#add hnsw configs and user defined params (if any)
|
64 |
+
results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
|
65 |
+
|
66 |
+
start = time.perf_counter()
|
67 |
+
miss_info = []
|
68 |
+
for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
|
69 |
+
results_dict['total_questions'] += 1
|
70 |
+
hit = False
|
71 |
+
#make Keyword, Vector, and Hybrid calls to Weaviate host
|
72 |
+
try:
|
73 |
+
kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
|
74 |
+
vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
|
75 |
+
|
76 |
+
#collect doc_ids and position of doc_ids to check for document matches
|
77 |
+
kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response, 1)}
|
78 |
+
vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response, 1)}
|
79 |
+
|
80 |
+
#extract doc_id for scoring purposes
|
81 |
+
doc_id = dataset.relevant_docs[query_id][0]
|
82 |
+
|
83 |
+
#increment hit_rate counters and mrr scores
|
84 |
+
if doc_id in kw_doc_ids:
|
85 |
+
results_dict['kw_hit_rate'] += 1
|
86 |
+
results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
|
87 |
+
hit = True
|
88 |
+
if doc_id in vector_doc_ids:
|
89 |
+
results_dict['vector_hit_rate'] += 1
|
90 |
+
results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
|
91 |
+
hit = True
|
92 |
+
|
93 |
+
# if no hits, let's capture that
|
94 |
+
if not hit:
|
95 |
+
results_dict['total_misses'] += 1
|
96 |
+
miss_info.append({'query': q, 'kw_response': kw_response, 'vector_response': vector_response})
|
97 |
+
except Exception as e:
|
98 |
+
print(e)
|
99 |
+
continue
|
100 |
+
|
101 |
+
|
102 |
+
#use raw counts to calculate final scores
|
103 |
+
calc_hit_rate_scores(results_dict)
|
104 |
+
calc_mrr_scores(results_dict)
|
105 |
+
|
106 |
+
end = time.perf_counter() - start
|
107 |
+
print(f'Total Processing Time: {round(end/60, 2)} minutes')
|
108 |
+
record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
|
109 |
+
|
110 |
+
if include_miss_info:
|
111 |
+
return results_dict, miss_info
|
112 |
+
return results_dict
|
llama_test.ipynb
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"Note: you may need to restart the kernel to use updated packages.\n",
|
13 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"%pip install huggingface_hub --q\n",
|
19 |
+
"%pip install ipywidgets --q"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 3,
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"from transformers.pipelines.text_generation import TextGenerationPipeline\n",
|
29 |
+
"from transformers import AutoConfig\n",
|
30 |
+
"import transformers"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 2,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [
|
38 |
+
{
|
39 |
+
"data": {
|
40 |
+
"application/vnd.jupyter.widget-view+json": {
|
41 |
+
"model_id": "f9c842f1bd7146e5a4e4d517450531ee",
|
42 |
+
"version_major": 2,
|
43 |
+
"version_minor": 0
|
44 |
+
},
|
45 |
+
"text/plain": [
|
46 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
"metadata": {},
|
50 |
+
"output_type": "display_data"
|
51 |
+
}
|
52 |
+
],
|
53 |
+
"source": [
|
54 |
+
"from huggingface_hub import notebook_login\n",
|
55 |
+
"notebook_login() #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"
|
56 |
+
]
|
57 |
+
}
|
58 |
+
],
|
59 |
+
"metadata": {
|
60 |
+
"kernelspec": {
|
61 |
+
"display_name": "venv",
|
62 |
+
"language": "python",
|
63 |
+
"name": "python3"
|
64 |
+
},
|
65 |
+
"language_info": {
|
66 |
+
"codemirror_mode": {
|
67 |
+
"name": "ipython",
|
68 |
+
"version": 3
|
69 |
+
},
|
70 |
+
"file_extension": ".py",
|
71 |
+
"mimetype": "text/x-python",
|
72 |
+
"name": "python",
|
73 |
+
"nbconvert_exporter": "python",
|
74 |
+
"pygments_lexer": "ipython3",
|
75 |
+
"version": "3.11.5"
|
76 |
+
}
|
77 |
+
},
|
78 |
+
"nbformat": 4,
|
79 |
+
"nbformat_minor": 2
|
80 |
+
}
|
openai_interface.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from openai import OpenAI
|
3 |
+
from typing import List, Any, Tuple
|
4 |
+
from tqdm import tqdm
|
5 |
+
import streamlit as st
|
6 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
7 |
+
|
8 |
+
from dotenv import load_dotenv, find_dotenv
|
9 |
+
load_dotenv(find_dotenv('env'), override=True)
|
10 |
+
|
11 |
+
try:
|
12 |
+
api_key = st.secrets['secrets']['OPENAI_API_KEY']
|
13 |
+
except:
|
14 |
+
api_key = os.environ['OPENAI_API_KEY']
|
15 |
+
class GPT_Turbo:
|
16 |
+
|
17 |
+
def __init__(self, model: str="gpt-3.5-turbo-0613", api_key: str=api_key):
|
18 |
+
self.model = model
|
19 |
+
self.client = OpenAI(api_key=api_key)
|
20 |
+
|
21 |
+
def get_chat_completion(self,
|
22 |
+
prompt: str,
|
23 |
+
system_message: str='You are a helpful assistant.',
|
24 |
+
temperature: int=0,
|
25 |
+
max_tokens: int=500,
|
26 |
+
stream: bool=False,
|
27 |
+
show_response: bool=False
|
28 |
+
) -> str:
|
29 |
+
messages = [
|
30 |
+
{'role': 'system', 'content': system_message},
|
31 |
+
{'role': 'assistant', 'content': prompt}
|
32 |
+
]
|
33 |
+
|
34 |
+
response = self.client.chat.completions.create( model=self.model,
|
35 |
+
messages=messages,
|
36 |
+
temperature=temperature,
|
37 |
+
max_tokens=max_tokens,
|
38 |
+
stream=stream)
|
39 |
+
if show_response:
|
40 |
+
return response
|
41 |
+
return response.choices[0].message.content
|
42 |
+
|
43 |
+
def multi_thread_request(self,
|
44 |
+
filepath: str,
|
45 |
+
prompt: str,
|
46 |
+
content: List[str],
|
47 |
+
temperature: int=0
|
48 |
+
) -> List[Any]:
|
49 |
+
|
50 |
+
data = []
|
51 |
+
with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
|
52 |
+
futures = [exec.submit(self.get_completion_from_messages, [{'role': 'user','content': f'{prompt} ```{c}```'}], temperature, 500, False) for c in content]
|
53 |
+
with open(filepath, 'a') as f:
|
54 |
+
for future in as_completed(futures):
|
55 |
+
result = future.result()
|
56 |
+
if len(data) % 10 == 0:
|
57 |
+
print(f'{len(data)} of {len(content)} completed.')
|
58 |
+
if result:
|
59 |
+
data.append(result)
|
60 |
+
self.write_to_file(file_handle=f, data=result)
|
61 |
+
return [res for res in data if res]
|
62 |
+
|
63 |
+
def generate_question_context_pairs(self,
|
64 |
+
context_tuple: Tuple[str, str],
|
65 |
+
num_questions_per_chunk: int=2,
|
66 |
+
max_words_per_question: int=10
|
67 |
+
) -> List[str]:
|
68 |
+
|
69 |
+
doc_id, context = context_tuple
|
70 |
+
prompt = f'Context information is included below enclosed in triple backticks. Given the context information and not prior knowledge, generate questions based on the below query.\n\nYou are an end user querying for information about your favorite podcast. \
|
71 |
+
Your task is to setup {num_questions_per_chunk} questions that can be answered using only the given context. The questions should be diverse in nature across the document and be no longer than {max_words_per_question} words. \
|
72 |
+
Restrict the questions to the context information provided.\n\
|
73 |
+
```{context}```\n\n'
|
74 |
+
|
75 |
+
response = self.get_completion_from_messages(prompt=prompt, temperature=0, max_tokens=500, show_response=True)
|
76 |
+
questions = response.choices[0].message["content"]
|
77 |
+
return (doc_id, questions)
|
78 |
+
|
79 |
+
def batch_generate_question_context_pairs(self,
|
80 |
+
context_tuple_list: List[Tuple[str, str]],
|
81 |
+
num_questions_per_chunk: int=2,
|
82 |
+
max_words_per_question: int=10
|
83 |
+
) -> List[Tuple[str, str]]:
|
84 |
+
data = []
|
85 |
+
progress = tqdm(unit="Generated Questions", total=len(context_tuple_list))
|
86 |
+
with ThreadPoolExecutor(max_workers=2*os.cpu_count()) as exec:
|
87 |
+
futures = [exec.submit(self.generate_question_context_pairs, context_tuple, num_questions_per_chunk, max_words_per_question) for context_tuple in context_tuple_list]
|
88 |
+
for future in as_completed(futures):
|
89 |
+
result = future.result()
|
90 |
+
if result:
|
91 |
+
data.append(result)
|
92 |
+
progress.update(1)
|
93 |
+
return data
|
94 |
+
|
95 |
+
def get_embedding(self):
|
96 |
+
pass
|
97 |
+
|
98 |
+
def write_to_file(self, file_handle, data: str) -> None:
|
99 |
+
file_handle.write(data)
|
100 |
+
file_handle.write('\n')
|
preprocessing.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import pandas as pd
|
4 |
+
from typing import List, Union, Dict
|
5 |
+
from loguru import logger
|
6 |
+
import pandas as pd
|
7 |
+
import pathlib
|
8 |
+
|
9 |
+
|
10 |
+
## Set of helper functions that support data preprocessing
|
11 |
+
class FileIO:
|
12 |
+
'''
|
13 |
+
Convenience class for saving and loading data in parquet and
|
14 |
+
json formats to/from disk.
|
15 |
+
'''
|
16 |
+
|
17 |
+
def save_as_parquet(self,
|
18 |
+
file_path: str,
|
19 |
+
data: Union[List[dict], pd.DataFrame],
|
20 |
+
overwrite: bool=False) -> None:
|
21 |
+
'''
|
22 |
+
Saves DataFrame to disk as a parquet file. Removes the index.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
-----
|
26 |
+
file_path : str
|
27 |
+
Output path to save file, if not included "parquet" will be appended
|
28 |
+
as file extension.
|
29 |
+
data : Union[List[dict], pd.DataFrame]
|
30 |
+
Data to save as parquet file. If data is a list of dicts, it will be
|
31 |
+
converted to a DataFrame before saving.
|
32 |
+
overwrite : bool
|
33 |
+
Overwrite existing file if True, otherwise raise FileExistsError.
|
34 |
+
'''
|
35 |
+
if isinstance(data, list):
|
36 |
+
data = self._convert_toDataFrame(data)
|
37 |
+
if not file_path.endswith('parquet'):
|
38 |
+
file_path = self._rename_file_extension(file_path, 'parquet')
|
39 |
+
self._check_file_path(file_path, overwrite=overwrite)
|
40 |
+
data.to_parquet(file_path, index=False)
|
41 |
+
logger.info(f'DataFrame saved as parquet file here: {file_path}')
|
42 |
+
|
43 |
+
def _convert_toDataFrame(self, data: List[dict]) -> pd.DataFrame:
|
44 |
+
return pd.DataFrame().from_dict(data)
|
45 |
+
|
46 |
+
def _rename_file_extension(self, file_path: str, extension: str):
|
47 |
+
'''
|
48 |
+
Renames file with appropriate extension if file_path
|
49 |
+
does not already have correct extension.
|
50 |
+
'''
|
51 |
+
prefix = os.path.splitext(file_path)[0]
|
52 |
+
file_path = prefix + '.' + extension
|
53 |
+
return file_path
|
54 |
+
|
55 |
+
def _check_file_path(self, file_path: str, overwrite: bool) -> None:
|
56 |
+
'''
|
57 |
+
Checks for existence of file and overwrite permissions.
|
58 |
+
'''
|
59 |
+
if os.path.exists(file_path) and overwrite == False:
|
60 |
+
raise FileExistsError(f'File by name {file_path} already exists, try using another file name or set overwrite to True.')
|
61 |
+
elif os.path.exists(file_path):
|
62 |
+
os.remove(file_path)
|
63 |
+
else:
|
64 |
+
file_name = os.path.basename(file_path)
|
65 |
+
dir_structure = file_path.replace(file_name, '')
|
66 |
+
pathlib.Path(dir_structure).mkdir(parents=True, exist_ok=True)
|
67 |
+
|
68 |
+
def load_parquet(self, file_path: str, verbose: bool=True) -> List[dict]:
|
69 |
+
'''
|
70 |
+
Loads parquet from disk, converts to pandas DataFrame as intermediate
|
71 |
+
step and outputs a list of dicts (docs).
|
72 |
+
'''
|
73 |
+
df = pd.read_parquet(file_path)
|
74 |
+
vector_labels = ['content_vector', 'image_vector', 'content_embedding']
|
75 |
+
for label in vector_labels:
|
76 |
+
if label in df.columns:
|
77 |
+
df[label] = df[label].apply(lambda x: x.tolist())
|
78 |
+
if verbose:
|
79 |
+
memory_usage = round(df.memory_usage().sum()/(1024*1024),2)
|
80 |
+
print(f'Shape of data: {df.values.shape}')
|
81 |
+
print(f'Memory Usage: {memory_usage}+ MB')
|
82 |
+
list_of_dicts = df.to_dict('records')
|
83 |
+
return list_of_dicts
|
84 |
+
|
85 |
+
def load_json(self, file_path: str):
|
86 |
+
'''
|
87 |
+
Loads json file from disk.
|
88 |
+
'''
|
89 |
+
with open(file_path) as f:
|
90 |
+
data = json.load(f)
|
91 |
+
return data
|
92 |
+
|
93 |
+
def save_as_json(self,
|
94 |
+
file_path: str,
|
95 |
+
data: Union[List[dict], dict],
|
96 |
+
indent: int=4,
|
97 |
+
overwrite: bool=False
|
98 |
+
) -> None:
|
99 |
+
'''
|
100 |
+
Saves data to disk as a json file. Data can be a list of dicts or a single dict.
|
101 |
+
'''
|
102 |
+
if not file_path.endswith('json'):
|
103 |
+
file_path = self._rename_file_extension(file_path, 'json')
|
104 |
+
self._check_file_path(file_path, overwrite=overwrite)
|
105 |
+
with open(file_path, 'w') as f:
|
106 |
+
json.dump(data, f, indent=indent)
|
107 |
+
logger.info(f'Data saved as json file here: {file_path}')
|
108 |
+
|
109 |
+
class Utilities:
|
110 |
+
|
111 |
+
def create_video_url(self, video_id: str, playlist_id: str):
|
112 |
+
'''
|
113 |
+
Creates a hyperlink to a video episode given a video_id and playlist_id.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
-----
|
117 |
+
video_id : str
|
118 |
+
Video id of the episode from YouTube
|
119 |
+
playlist_id : str
|
120 |
+
Playlist id of the episode from YouTube
|
121 |
+
'''
|
122 |
+
return f'https://www.youtube.com/watch?v={video_id}&list={playlist_id}'
|
123 |
+
|
prompt_templates.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
question_answering_system = '''
|
2 |
+
You are the host of the show Impact Theory, and your name is Tom Bilyeu. The description of your show is as follows:
|
3 |
+
If you’re looking to thrive in uncertain times, achieve unprecedented goals, and improve the most meaningful aspects of your life, then Impact Theory is the show for you. Hosted by Tom Bilyeu, a voracious learner and hyper-successful entrepreneur, the show investigates and analyzes the most useful topics with the world’s most sought-after guests.
|
4 |
+
Bilyeu attacks each episode with a clear desire to further evolve the holistic skillset that allowed him to co-found the billion-dollar company Quest Nutrition, generate over half a billion organic views on his content, build a thriving marriage of over 20 years, and quantifiably improve the lives of over 10,000 people through his school, Impact Theory University.
|
5 |
+
Bilyeu’s insatiable hunger for knowledge gives the show urgency, relevance, and depth while leaving listeners with the knowledge, tools, and empowerment to take control of their lives and develop true personal power.
|
6 |
+
'''
|
7 |
+
|
8 |
+
question_answering_prompt_single = '''
|
9 |
+
Use the below context enclosed in triple back ticks to answer the question. If the context does not provide enough information to answer the question, then use any knowledge you have to answer the question.\n
|
10 |
+
```{context}```\n
|
11 |
+
Question:\n
|
12 |
+
{question}.\n
|
13 |
+
Answer:
|
14 |
+
'''
|
15 |
+
|
16 |
+
question_answering_prompt_series = '''
|
17 |
+
Your task is to synthesize and reason over a series of transcripts of an interview between Tom Bilyeu and his guest(s).
|
18 |
+
After your synthesis, use the series of transcripts to answer the below question. The series will be in the following format:\n
|
19 |
+
```
|
20 |
+
Show Summary: <summary>
|
21 |
+
Show Guest: <guest>
|
22 |
+
Transcript: <transcript>
|
23 |
+
```\n\n
|
24 |
+
Start Series:
|
25 |
+
```
|
26 |
+
{series}
|
27 |
+
```
|
28 |
+
Question:\n
|
29 |
+
{question}\n
|
30 |
+
Answer the question and provide reasoning if necessary to explain the answer.\n
|
31 |
+
If the context does not provide enough information to answer the question, then \n
|
32 |
+
state that you cannot answer the question with the provided context.\n
|
33 |
+
|
34 |
+
Answer:
|
35 |
+
'''
|
36 |
+
|
37 |
+
context_block = '''
|
38 |
+
Show Summary: {summary}
|
39 |
+
Show Guest: {guest}
|
40 |
+
Transcript: {transcript}
|
41 |
+
'''
|
42 |
+
|
43 |
+
qa_generation_prompt = '''
|
44 |
+
Impact Theory episode summary and episode guest are below:
|
45 |
+
|
46 |
+
---------------------
|
47 |
+
Summary: {summary}
|
48 |
+
---------------------
|
49 |
+
Guest: {guest}
|
50 |
+
---------------------
|
51 |
+
Given the Summary and Guest of the episode as context \
|
52 |
+
use the following randomly selected transcript section \
|
53 |
+
of the episode and not prior knowledge, generate questions that can \
|
54 |
+
be answered by the transcript section:
|
55 |
+
|
56 |
+
---------------------
|
57 |
+
Transcript: {transcript}
|
58 |
+
---------------------
|
59 |
+
|
60 |
+
Your task is to create {num_questions_per_chunk} questions that can \
|
61 |
+
only be answered given the previous context and transcript details. \
|
62 |
+
The question should randomly start with How, Why, or What.
|
63 |
+
'''
|
prompt_templates_luis.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
question_answering_system = '''
|
2 |
+
You are the host of the show Impact Theory, and your name is Tom Bilyeu. The description of your show is as follows:
|
3 |
+
If you’re looking to thrive in uncertain times, achieve unprecedented goals, and improve the most meaningful aspects of your life, then Impact Theory is the show for you. Hosted by Tom Bilyeu, a voracious learner and hyper-successful entrepreneur, the show investigates and analyzes the most useful topics with the world’s most sought-after guests.
|
4 |
+
Bilyeu attacks each episode with a clear desire to further evolve the holistic skillset that allowed him to co-found the billion-dollar company Quest Nutrition, generate over half a billion organic views on his content, build a thriving marriage of over 20 years, and quantifiably improve the lives of over 10,000 people through his school, Impact Theory University.
|
5 |
+
Bilyeu’s insatiable hunger for knowledge gives the show urgency, relevance, and depth while leaving listeners with the knowledge, tools, and empowerment to take control of their lives and develop true personal power.
|
6 |
+
'''
|
7 |
+
|
8 |
+
question_answering_prompt_single = '''
|
9 |
+
Use the below context enclosed in triple back ticks to answer the question. If the context does not provide enough information to answer the question, then use any knowledge you have to answer the question.\n
|
10 |
+
```{context}```\n
|
11 |
+
Question:\n
|
12 |
+
{question}.\n
|
13 |
+
Answer:
|
14 |
+
'''
|
15 |
+
|
16 |
+
question_answering_prompt_series = '''
|
17 |
+
Your task is to synthesize and reason over a series of transcripts of an interview between Tom Bilyeu and his guest(s).
|
18 |
+
After your synthesis, use the series of transcripts to answer the below question. The series will be in the following format:\n
|
19 |
+
```
|
20 |
+
Show Summary: <summary>
|
21 |
+
Show Guest: <guest>
|
22 |
+
Transcript: <transcript>
|
23 |
+
```\n\n
|
24 |
+
Start Series:
|
25 |
+
```
|
26 |
+
{series}
|
27 |
+
```
|
28 |
+
Question:\n
|
29 |
+
{question}\n
|
30 |
+
Answer the question and provide reasoning if necessary to explain the answer.\n
|
31 |
+
If the context does not provide enough information to answer the question, then \n
|
32 |
+
state that you cannot answer the question with the provided context.\n
|
33 |
+
|
34 |
+
Answer:
|
35 |
+
'''
|
36 |
+
|
37 |
+
context_block = '''
|
38 |
+
Show Summary: {summary}
|
39 |
+
Show Guest: {guest}
|
40 |
+
Transcript: {transcript}
|
41 |
+
'''
|
42 |
+
|
43 |
+
qa_generation_prompt = '''
|
44 |
+
Impact Theory episode summary and episode guest are below:
|
45 |
+
|
46 |
+
---------------------
|
47 |
+
Summary: {summary}
|
48 |
+
---------------------
|
49 |
+
Guest: {guest}
|
50 |
+
---------------------
|
51 |
+
Given the Summary and Guest of the episode as context \
|
52 |
+
use the following randomly selected transcript section \
|
53 |
+
of the episode and not prior knowledge, generate questions that can \
|
54 |
+
be answered by the transcript section:
|
55 |
+
|
56 |
+
---------------------
|
57 |
+
Transcript: {transcript}
|
58 |
+
---------------------
|
59 |
+
|
60 |
+
Your task is to create {num_questions_per_chunk} questions that can \
|
61 |
+
only be answered given the previous context and transcript details. \
|
62 |
+
The question should randomly start with How, Why, or What.
|
63 |
+
'''
|
readme2.md
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Welcome to Vector Search Applications with LLMs
|
2 |
+
This is the course repository for Vector Search Applications with LLMs taught by [Chris Sanchez](https://www.linkedin.com/in/excellenceisahabit/) with assistance from [Matias Weber](https://www.linkedin.com/in/matiasweber/).
|
3 |
+
The course is desgined to teach search and discovery industry best practices culminating in a demo Retrieval Augmented Generation (RAG) application. Along the way students will learn all of the components of a RAG system to include data preprocessing, embedding creation, vector database selection, indexing, retrieval systems, reranking, retrieval evaluation, question answering through an LLM and UI implementation through Streamlit.
|
4 |
+
|
5 |
+
# Prerequisites - Technical Experience
|
6 |
+
Students are expected to have the following technical skills prior to enrolling. Students who do not meet these prerequisites will likely have an overly challenging learning experience:
|
7 |
+
- Minimum of 1-year experience coding in Python. Skillsets should include programming using OOP, dictionary and list comprehensions, lambda functions, setting up virtual environments, comfortability with git version control.
|
8 |
+
- Professional or academic experience working with search engines.
|
9 |
+
- Ability to comfortably navigate the command line to include familiarity with docker.
|
10 |
+
- Nice to have but not strictly required:
|
11 |
+
- experience fine-tuning a ML model
|
12 |
+
- familiarity with the Streamlit API
|
13 |
+
- familiarity with making inference calls to a Generative LLM (OpenAI or Llama-2)
|
14 |
+
# Prerequisites - Administrative
|
15 |
+
1. Students will need access to their own compute environment, whether locally or remote. There are no hard requirements for RAM or CPU processing power, but in general the more punch the better.
|
16 |
+
2. Students will need accounts with the following organizations:
|
17 |
+
- Either an [OpenAI](https://openai.com) account **(RECOMMENDED)** or a [HuggingFace](https://huggingface.co/join) account. Students have the option of either using a paid LLM service (OpenAI) or using the open source `meta-llama/Llama-2-7b-chat-hf` model. Students choosing the latter option will first need to [register with Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) to request access to the Llama-2 model.
|
18 |
+
- An account with [weaviate.io](https://weaviate.io). The current iteration of this course will use Weaviate as a sparse and dense vector database. Weaviate offers free cloud instance cluster resources for 21 days (as of November 2023). **Students are advised to NOT CREATE** a Weaviate cloud cluster until the course officially starts.
|
19 |
+
- A standard [Github](https://github.com/) account in order to fork this repo, clone a copy, and submit commits to the fork as needed throughout the course.
|
20 |
+
|
21 |
+
# Setup
|
22 |
+
1. Fork this course repo (see upper right hand corner of the repo web page).
|
23 |
+
<img src="assets/forkbutton.png" alt="fork button" width="300" height="auto">
|
24 |
+
3. Clone a copy of the forked repo into the dev environment of your choice. Navigate into the cloned `vectorsearch-applications` directory.
|
25 |
+
4. Create a python virtual environment using your library of choice. Here's an example using [`conda`](https://docs.conda.io/projects/miniconda/en/latest/):
|
26 |
+
```
|
27 |
+
conda create --name impactenv -y python=3.10
|
28 |
+
```
|
29 |
+
4. Once the environment is created, activate the environment and install dependencies.
|
30 |
+
```
|
31 |
+
conda activate impactenv
|
32 |
+
|
33 |
+
pip install -r requirements.txt
|
34 |
+
```
|
35 |
+
5. Last but not least create a `.env` text file in your cloned repo. At a minimum, add the following environment variables:
|
36 |
+
```
|
37 |
+
OPENAI_API_KEY= "your OpenAI account API Key"
|
38 |
+
HF_TOKEN= "your HuggingFace account token" <--- Optional: not required if using OpenAI
|
39 |
+
WEAVIATE_API_KEY= "your Weaviate cluster API Key" <--- you will get this on Day One of the course
|
40 |
+
WEAVIATE_ENDPOINT= "your Weaviate cluster endpoint" <--- you will get this on Day One of the course
|
41 |
+
```
|
42 |
+
6. If you've made it this far, you are ready to start the course. Enjoy the process!
|
43 |
+
<img src="assets/getsome.jpg" alt="jocko" width="500" height="auto">
|
requirements.txt
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
beautifulsoup4==4.12.2
|
2 |
+
datasets==2.14.3
|
3 |
+
huggingface-hub==0.16.4
|
4 |
+
ipython==8.14.0
|
5 |
+
ipywidgets==8.1.1
|
6 |
+
jedi==0.19.0
|
7 |
+
jupyter-events==0.7.0
|
8 |
+
jupyter-lsp==2.2.0
|
9 |
+
jupyter_client==8.3.0
|
10 |
+
jupyter_core==5.3.1
|
11 |
+
jupyter_server==2.7.0
|
12 |
+
jupyter_server_terminals==0.4.4
|
13 |
+
jupyterlab==4.0.4
|
14 |
+
jupyterlab-pygments==0.2.2
|
15 |
+
jupyterlab-widgets==3.0.9
|
16 |
+
jupyterlab_server==2.24.0
|
17 |
+
langchain==0.0.310
|
18 |
+
langcodes==3.3.0
|
19 |
+
langsmith==0.0.43
|
20 |
+
llama-hub==0.0.47post1
|
21 |
+
llama-index==0.9.6.post1
|
22 |
+
loguru==0.7.0
|
23 |
+
matplotlib==3.7.2
|
24 |
+
matplotlib-inline==0.1.6
|
25 |
+
numpy==1.24.4
|
26 |
+
openai==1.3.5
|
27 |
+
pandas==2.0.3
|
28 |
+
protobuf==4.23.4
|
29 |
+
pyarrow==12.0.1
|
30 |
+
python-dotenv==1.0.0
|
31 |
+
rank-bm25==0.2.2
|
32 |
+
requests==2.31.0
|
33 |
+
requests-oauthlib==1.3.1
|
34 |
+
rich==13.7.0
|
35 |
+
sentence-transformers==2.2.2
|
36 |
+
streamlit==1.28.2
|
37 |
+
tiktoken==0.5.1
|
38 |
+
tokenizers==0.13.3
|
39 |
+
torch==2.0.1
|
40 |
+
tqdm==4.66.1
|
41 |
+
transformers==4.33.1
|
42 |
+
weaviate-client==3.25.3
|
43 |
+
polars>=0.19
|
44 |
+
plotly
|
45 |
+
angle-emb==0.1.5 # for UAE-Large-V1 model
|
46 |
+
streamlit-option-menu==0.3.6
|
47 |
+
hydralit_components==1.0.10
|
48 |
+
pathlib
|
49 |
+
gdown
|
50 |
+
modal
|
reranker.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import CrossEncoder
|
2 |
+
from torch.nn import Sigmoid
|
3 |
+
from typing import List, Union
|
4 |
+
import numpy as np
|
5 |
+
from loguru import logger
|
6 |
+
|
7 |
+
class ReRanker(CrossEncoder):
|
8 |
+
'''
|
9 |
+
Cross-Encoder models achieve higher performance than Bi-Encoders,
|
10 |
+
however, they do not scale well to large datasets. The lack of scalability
|
11 |
+
is due to the underlying cross-attention mechanism, which is computationally
|
12 |
+
expensive. Thus a Bi-Encoder is best used for 1st-stage document retrieval and
|
13 |
+
a Cross-Encoder is used to re-rank the retrieved documents.
|
14 |
+
|
15 |
+
https://www.sbert.net/examples/applications/cross-encoder/README.html
|
16 |
+
'''
|
17 |
+
|
18 |
+
def __init__(self,
|
19 |
+
model_name: str='cross-encoder/ms-marco-MiniLM-L-6-v2',
|
20 |
+
**kwargs
|
21 |
+
):
|
22 |
+
super().__init__(model_name=model_name,
|
23 |
+
**kwargs)
|
24 |
+
self.model_name = model_name
|
25 |
+
self.score_field = 'cross_score'
|
26 |
+
self.activation_fct = Sigmoid()
|
27 |
+
|
28 |
+
def _cross_encoder_score(self,
|
29 |
+
results: List[dict],
|
30 |
+
query: str,
|
31 |
+
hit_field: str='content',
|
32 |
+
apply_sigmoid: bool=True,
|
33 |
+
return_scores: bool=False
|
34 |
+
) -> Union[np.array, None]:
|
35 |
+
'''
|
36 |
+
Given a list of hits from a Retriever:
|
37 |
+
1. Scores hits by passing query and results through CrossEncoder model.
|
38 |
+
2. Adds cross-score key to results dictionary.
|
39 |
+
3. If desired returns np.array of Cross Encoder scores.
|
40 |
+
'''
|
41 |
+
activation_fct = self.activation_fct if apply_sigmoid else None
|
42 |
+
#build query/content list
|
43 |
+
cross_inp = [[query, hit[hit_field]] for hit in results]
|
44 |
+
#get scores
|
45 |
+
cross_scores = self.predict(cross_inp, activation_fct=activation_fct)
|
46 |
+
for i, result in enumerate(results):
|
47 |
+
result[self.score_field]=cross_scores[i]
|
48 |
+
|
49 |
+
if return_scores:return cross_scores
|
50 |
+
|
51 |
+
def rerank(self,
|
52 |
+
results: List[dict],
|
53 |
+
query: str,
|
54 |
+
top_k: int=10,
|
55 |
+
apply_sigmoid: bool=True,
|
56 |
+
threshold: float=None
|
57 |
+
) -> List[dict]:
|
58 |
+
'''
|
59 |
+
Given a list of hits from a Retriever:
|
60 |
+
1. Scores hits by passing query and results through CrossEncoder model.
|
61 |
+
2. Adds cross_score key to results dictionary.
|
62 |
+
3. Returns reranked results limited by either a threshold value or top_k.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
-----
|
66 |
+
results : List[dict]
|
67 |
+
List of results from the Weaviate client
|
68 |
+
query : str
|
69 |
+
User query
|
70 |
+
top_k : int=10
|
71 |
+
Number of results to return
|
72 |
+
apply_sigmoid : bool=True
|
73 |
+
Whether to apply sigmoid activation to cross-encoder scores. If False,
|
74 |
+
returns raw cross-encoder scores (logits).
|
75 |
+
threshold : float=None
|
76 |
+
Minimum cross-encoder score to return. If no hits are above threshold,
|
77 |
+
returns top_k hits.
|
78 |
+
'''
|
79 |
+
# Sort results by the cross-encoder scores
|
80 |
+
self._cross_encoder_score(results=results, query=query, apply_sigmoid=apply_sigmoid)
|
81 |
+
|
82 |
+
sorted_hits = sorted(results, key=lambda x: x[self.score_field], reverse=True)
|
83 |
+
if threshold or threshold == 0:
|
84 |
+
filtered_hits = [hit for hit in sorted_hits if hit[self.score_field] >= threshold]
|
85 |
+
if not any(filtered_hits):
|
86 |
+
logger.warning(f'No hits above threshold {threshold}. Returning top {top_k} hits.')
|
87 |
+
return sorted_hits[:top_k]
|
88 |
+
return filtered_hits
|
89 |
+
return sorted_hits[:top_k]
|
retrieval_evaluation.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#external files
|
2 |
+
from openai_interface import GPT_Turbo
|
3 |
+
from weaviate_interface import WeaviateClient
|
4 |
+
from llama_index.finetuning import EmbeddingQAFinetuneDataset
|
5 |
+
from prompt_templates import qa_generation_prompt
|
6 |
+
from reranker import ReRanker
|
7 |
+
|
8 |
+
#standard library imports
|
9 |
+
import json
|
10 |
+
import time
|
11 |
+
import uuid
|
12 |
+
import os
|
13 |
+
import re
|
14 |
+
import random
|
15 |
+
from datetime import datetime
|
16 |
+
from typing import List, Dict, Tuple, Union, Literal
|
17 |
+
|
18 |
+
#misc
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
|
22 |
+
class QueryContextGenerator:
|
23 |
+
'''
|
24 |
+
Class designed for the generation of query/context pairs using a
|
25 |
+
Generative LLM. The LLM is used to generate questions from a given
|
26 |
+
corpus of text. The query/context pairs can be used to fine-tune
|
27 |
+
an embedding model using a MultipleNegativesRankingLoss loss function
|
28 |
+
or can be used to create evaluation datasets for retrieval models.
|
29 |
+
'''
|
30 |
+
def __init__(self, openai_key: str, model_id: str='gpt-3.5-turbo-0613'):
|
31 |
+
self.llm = GPT_Turbo(model=model_id, api_key=openai_key)
|
32 |
+
|
33 |
+
def clean_validate_data(self,
|
34 |
+
data: List[dict],
|
35 |
+
valid_fields: List[str]=['content', 'summary', 'guest', 'doc_id'],
|
36 |
+
total_chars: int=950
|
37 |
+
) -> List[dict]:
|
38 |
+
'''
|
39 |
+
Strip original data chunks so they only contain valid_fields.
|
40 |
+
Remove any chunks less than total_chars in size. Prevents LLM
|
41 |
+
from asking questions from sparse content.
|
42 |
+
'''
|
43 |
+
clean_docs = [{k:v for k,v in d.items() if k in valid_fields} for d in data]
|
44 |
+
valid_docs = [d for d in clean_docs if len(d['content']) > total_chars]
|
45 |
+
return valid_docs
|
46 |
+
|
47 |
+
def train_val_split(self,
|
48 |
+
data: List[dict],
|
49 |
+
n_train_questions: int,
|
50 |
+
n_val_questions: int,
|
51 |
+
n_questions_per_chunk: int=2,
|
52 |
+
total_chars: int=950):
|
53 |
+
'''
|
54 |
+
Splits corpus into training and validation sets. Training and
|
55 |
+
validation samples are randomly selected from the corpus. total_chars
|
56 |
+
parameter is set based on pre-analysis of average doc length in the
|
57 |
+
training corpus.
|
58 |
+
'''
|
59 |
+
clean_data = self.clean_validate_data(data, total_chars=total_chars)
|
60 |
+
random.shuffle(clean_data)
|
61 |
+
train_index = n_train_questions//n_questions_per_chunk
|
62 |
+
valid_index = n_val_questions//n_questions_per_chunk
|
63 |
+
end_index = valid_index + train_index
|
64 |
+
if end_index > len(clean_data):
|
65 |
+
raise ValueError('Cannot create dataset with desired number of questions, try using a larger dataset')
|
66 |
+
train_data = clean_data[:train_index]
|
67 |
+
valid_data = clean_data[train_index:end_index]
|
68 |
+
print(f'Length Training Data: {len(train_data)}')
|
69 |
+
print(f'Length Validation Data: {len(valid_data)}')
|
70 |
+
return train_data, valid_data
|
71 |
+
|
72 |
+
def generate_qa_embedding_pairs(
|
73 |
+
self,
|
74 |
+
data: List[dict],
|
75 |
+
generate_prompt_tmpl: str=None,
|
76 |
+
num_questions_per_chunk: int = 2,
|
77 |
+
) -> EmbeddingQAFinetuneDataset:
|
78 |
+
"""
|
79 |
+
Generate query/context pairs from a list of documents. The query/context pairs
|
80 |
+
can be used for fine-tuning an embedding model using a MultipleNegativesRankingLoss
|
81 |
+
or can be used to create an evaluation dataset for retrieval models.
|
82 |
+
|
83 |
+
This function was adapted for this course from the llama_index.finetuning.common module:
|
84 |
+
https://github.com/run-llama/llama_index/blob/main/llama_index/finetuning/embeddings/common.py
|
85 |
+
"""
|
86 |
+
generate_prompt_tmpl = qa_generation_prompt if not generate_prompt_tmpl else generate_prompt_tmpl
|
87 |
+
queries = {}
|
88 |
+
relevant_docs = {}
|
89 |
+
corpus = {chunk['doc_id'] : chunk['content'] for chunk in data}
|
90 |
+
for chunk in tqdm(data):
|
91 |
+
summary = chunk['summary']
|
92 |
+
guest = chunk['guest']
|
93 |
+
transcript = chunk['content']
|
94 |
+
node_id = chunk['doc_id']
|
95 |
+
query = generate_prompt_tmpl.format(summary=summary,
|
96 |
+
guest=guest,
|
97 |
+
transcript=transcript,
|
98 |
+
num_questions_per_chunk=num_questions_per_chunk)
|
99 |
+
try:
|
100 |
+
response = self.llm.get_chat_completion(prompt=query, temperature=0.1, max_tokens=100)
|
101 |
+
except Exception as e:
|
102 |
+
print(e)
|
103 |
+
continue
|
104 |
+
result = str(response).strip().split("\n")
|
105 |
+
questions = [
|
106 |
+
re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
|
107 |
+
]
|
108 |
+
questions = [question for question in questions if len(question) > 0]
|
109 |
+
|
110 |
+
for question in questions:
|
111 |
+
question_id = str(uuid.uuid4())
|
112 |
+
queries[question_id] = question
|
113 |
+
relevant_docs[question_id] = [node_id]
|
114 |
+
|
115 |
+
# construct dataset
|
116 |
+
return EmbeddingQAFinetuneDataset(
|
117 |
+
queries=queries, corpus=corpus, relevant_docs=relevant_docs
|
118 |
+
)
|
119 |
+
|
120 |
+
def execute_evaluation(dataset: EmbeddingQAFinetuneDataset,
|
121 |
+
class_name: str,
|
122 |
+
retriever: WeaviateClient,
|
123 |
+
reranker: ReRanker=None,
|
124 |
+
alpha: float=0.5,
|
125 |
+
retrieve_limit: int=100,
|
126 |
+
top_k: int=5,
|
127 |
+
chunk_size: int=256,
|
128 |
+
hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef'],
|
129 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']='all',
|
130 |
+
display_properties: List[str]=['doc_id', 'content'],
|
131 |
+
dir_outpath: str='./eval_results',
|
132 |
+
include_miss_info: bool=False,
|
133 |
+
user_def_params: dict=None
|
134 |
+
) -> Union[dict, Tuple[dict, List[dict]]]:
|
135 |
+
'''
|
136 |
+
Given a dataset, a retriever, and a reranker, evaluate the performance of the retriever and reranker.
|
137 |
+
Returns a dict of kw, vector, and hybrid hit rates and mrr scores. If inlude_miss_info is True, will
|
138 |
+
also return a list of kw and vector responses and their associated queries that did not return a hit.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
-----
|
142 |
+
dataset: EmbeddingQAFinetuneDataset
|
143 |
+
Dataset to be used for evaluation
|
144 |
+
class_name: str
|
145 |
+
Name of Class on Weaviate host to be used for retrieval
|
146 |
+
retriever: WeaviateClient
|
147 |
+
WeaviateClient object to be used for retrieval
|
148 |
+
reranker: ReRanker
|
149 |
+
ReRanker model to be used for results reranking
|
150 |
+
alpha: float=0.5
|
151 |
+
Weighting factor for BM25 and Vector search.
|
152 |
+
alpha can be any number from 0 to 1, defaulting to 0.5:
|
153 |
+
alpha = 0 executes a pure keyword search method (BM25)
|
154 |
+
alpha = 0.5 weighs the BM25 and vector methods evenly
|
155 |
+
alpha = 1 executes a pure vector search method
|
156 |
+
retrieve_limit: int=5
|
157 |
+
Number of documents to retrieve from Weaviate host
|
158 |
+
top_k: int=5
|
159 |
+
Number of top results to evaluate
|
160 |
+
chunk_size: int=256
|
161 |
+
Number of tokens used to chunk text
|
162 |
+
hnsw_config_keys: List[str]=['maxConnections', 'efConstruction', 'ef']
|
163 |
+
List of keys to be used for retrieving HNSW Index parameters from Weaviate host
|
164 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']='all'
|
165 |
+
Type of search to be evaluated. Options are 'kw', 'vector', 'hybrid', or 'all'
|
166 |
+
display_properties: List[str]=['doc_id', 'content']
|
167 |
+
List of properties to be returned from Weaviate host for display in response
|
168 |
+
dir_outpath: str='./eval_results'
|
169 |
+
Directory path for saving results. Directory will be created if it does not
|
170 |
+
already exist.
|
171 |
+
include_miss_info: bool=False
|
172 |
+
Option to include queries and their associated search response values
|
173 |
+
for queries that are "total misses"
|
174 |
+
user_def_params : dict=None
|
175 |
+
Option for user to pass in a dictionary of user-defined parameters and their values.
|
176 |
+
Will be automatically added to the results_dict if correct type is passed.
|
177 |
+
'''
|
178 |
+
|
179 |
+
reranker_name = reranker.model_name if reranker else "None"
|
180 |
+
|
181 |
+
results_dict = {'n':retrieve_limit,
|
182 |
+
'top_k': top_k,
|
183 |
+
'alpha': alpha,
|
184 |
+
'Retriever': retriever.model_name_or_path,
|
185 |
+
'Ranker': reranker_name,
|
186 |
+
'chunk_size': chunk_size,
|
187 |
+
'kw_hit_rate': 0,
|
188 |
+
'kw_mrr': 0,
|
189 |
+
'vector_hit_rate': 0,
|
190 |
+
'vector_mrr': 0,
|
191 |
+
'hybrid_hit_rate':0,
|
192 |
+
'hybrid_mrr': 0,
|
193 |
+
'total_misses': 0,
|
194 |
+
'total_questions':0
|
195 |
+
}
|
196 |
+
#add extra params to results_dict
|
197 |
+
results_dict = add_params(retriever, class_name, results_dict, user_def_params, hnsw_config_keys)
|
198 |
+
|
199 |
+
start = time.perf_counter()
|
200 |
+
miss_info = []
|
201 |
+
for query_id, q in tqdm(dataset.queries.items(), 'Queries'):
|
202 |
+
results_dict['total_questions'] += 1
|
203 |
+
hit = False
|
204 |
+
#make Keyword, Vector, and Hybrid calls to Weaviate host
|
205 |
+
try:
|
206 |
+
kw_response = retriever.keyword_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
|
207 |
+
vector_response = retriever.vector_search(request=q, class_name=class_name, limit=retrieve_limit, display_properties=display_properties)
|
208 |
+
hybrid_response = retriever.hybrid_search(request=q, class_name=class_name, alpha=alpha, limit=retrieve_limit, display_properties=display_properties)
|
209 |
+
#rerank returned responses if reranker is provided
|
210 |
+
if reranker:
|
211 |
+
kw_response = reranker.rerank(kw_response, q, top_k=top_k)
|
212 |
+
vector_response = reranker.rerank(vector_response, q, top_k=top_k)
|
213 |
+
hybrid_response = reranker.rerank(hybrid_response, q, top_k=top_k)
|
214 |
+
|
215 |
+
#collect doc_ids to check for document matches (include only results_top_k)
|
216 |
+
kw_doc_ids = {result['doc_id']:i for i, result in enumerate(kw_response[:top_k], 1)}
|
217 |
+
vector_doc_ids = {result['doc_id']:i for i, result in enumerate(vector_response[:top_k], 1)}
|
218 |
+
hybrid_doc_ids = {result['doc_id']:i for i, result in enumerate(hybrid_response[:top_k], 1)}
|
219 |
+
|
220 |
+
#extract doc_id for scoring purposes
|
221 |
+
doc_id = dataset.relevant_docs[query_id][0]
|
222 |
+
|
223 |
+
#increment hit_rate counters and mrr scores
|
224 |
+
if doc_id in kw_doc_ids:
|
225 |
+
results_dict['kw_hit_rate'] += 1
|
226 |
+
results_dict['kw_mrr'] += 1/kw_doc_ids[doc_id]
|
227 |
+
hit = True
|
228 |
+
if doc_id in vector_doc_ids:
|
229 |
+
results_dict['vector_hit_rate'] += 1
|
230 |
+
results_dict['vector_mrr'] += 1/vector_doc_ids[doc_id]
|
231 |
+
hit = True
|
232 |
+
if doc_id in hybrid_doc_ids:
|
233 |
+
results_dict['hybrid_hit_rate'] += 1
|
234 |
+
results_dict['hybrid_mrr'] += 1/hybrid_doc_ids[doc_id]
|
235 |
+
hit = True
|
236 |
+
# if no hits, let's capture that
|
237 |
+
if not hit:
|
238 |
+
results_dict['total_misses'] += 1
|
239 |
+
miss_info.append({'query': q,
|
240 |
+
'answer': dataset.corpus[doc_id],
|
241 |
+
'doc_id': doc_id,
|
242 |
+
'kw_response': kw_response,
|
243 |
+
'vector_response': vector_response,
|
244 |
+
'hybrid_response': hybrid_response})
|
245 |
+
except Exception as e:
|
246 |
+
print(e)
|
247 |
+
continue
|
248 |
+
|
249 |
+
#use raw counts to calculate final scores
|
250 |
+
calc_hit_rate_scores(results_dict, search_type=search_type)
|
251 |
+
calc_mrr_scores(results_dict, search_type=search_type)
|
252 |
+
|
253 |
+
end = time.perf_counter() - start
|
254 |
+
print(f'Total Processing Time: {round(end/60, 2)} minutes')
|
255 |
+
record_results(results_dict, chunk_size, dir_outpath=dir_outpath, as_text=True)
|
256 |
+
|
257 |
+
if include_miss_info:
|
258 |
+
return results_dict, miss_info
|
259 |
+
return results_dict
|
260 |
+
|
261 |
+
def calc_hit_rate_scores(results_dict: Dict[str, Union[str, int]],
|
262 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']=['kw', 'vector']
|
263 |
+
) -> None:
|
264 |
+
if search_type == 'all':
|
265 |
+
search_type = ['kw', 'vector', 'hybrid']
|
266 |
+
for prefix in search_type:
|
267 |
+
results_dict[f'{prefix}_hit_rate'] = round(results_dict[f'{prefix}_hit_rate']/results_dict['total_questions'],2)
|
268 |
+
|
269 |
+
def calc_mrr_scores(results_dict: Dict[str, Union[str, int]],
|
270 |
+
search_type: Literal['kw', 'vector', 'hybrid', 'all']=['kw', 'vector']
|
271 |
+
) -> None:
|
272 |
+
if search_type == 'all':
|
273 |
+
search_type = ['kw', 'vector', 'hybrid']
|
274 |
+
for prefix in search_type:
|
275 |
+
results_dict[f'{prefix}_mrr'] = round(results_dict[f'{prefix}_mrr']/results_dict['total_questions'],2)
|
276 |
+
|
277 |
+
def create_dir(dir_path: str) -> None:
|
278 |
+
'''
|
279 |
+
Checks if directory exists, and creates new directory
|
280 |
+
if it does not exist
|
281 |
+
'''
|
282 |
+
if not os.path.exists(dir_path):
|
283 |
+
os.makedirs(dir_path)
|
284 |
+
|
285 |
+
def record_results(results_dict: Dict[str, Union[str, int]],
|
286 |
+
chunk_size: int,
|
287 |
+
dir_outpath: str='./eval_results',
|
288 |
+
as_text: bool=False
|
289 |
+
) -> None:
|
290 |
+
'''
|
291 |
+
Write results to output file in either txt or json format
|
292 |
+
|
293 |
+
Args:
|
294 |
+
-----
|
295 |
+
results_dict: Dict[str, Union[str, int]]
|
296 |
+
Dictionary containing results of evaluation
|
297 |
+
chunk_size: int
|
298 |
+
Size of text chunks in tokens
|
299 |
+
dir_outpath: str
|
300 |
+
Path to output directory. Directory only, filename is hardcoded
|
301 |
+
as part of this function.
|
302 |
+
as_text: bool
|
303 |
+
If True, write results as text file. If False, write as json file.
|
304 |
+
'''
|
305 |
+
create_dir(dir_outpath)
|
306 |
+
time_marker = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
307 |
+
ext = 'txt' if as_text else 'json'
|
308 |
+
path = os.path.join(dir_outpath, f'retrieval_eval_{chunk_size}_{time_marker}.{ext}')
|
309 |
+
if as_text:
|
310 |
+
with open(path, 'a') as f:
|
311 |
+
f.write(f"{results_dict}\n")
|
312 |
+
else:
|
313 |
+
with open(path, 'w') as f:
|
314 |
+
json.dump(results_dict, f, indent=4)
|
315 |
+
|
316 |
+
def add_params(client: WeaviateClient,
|
317 |
+
class_name: str,
|
318 |
+
results_dict: dict,
|
319 |
+
param_options: dict,
|
320 |
+
hnsw_config_keys: List[str]
|
321 |
+
) -> dict:
|
322 |
+
hnsw_params = {k:v for k,v in client.show_class_config(class_name)['vectorIndexConfig'].items() if k in hnsw_config_keys}
|
323 |
+
if hnsw_params:
|
324 |
+
results_dict = {**results_dict, **hnsw_params}
|
325 |
+
if param_options and isinstance(param_options, dict):
|
326 |
+
results_dict = {**results_dict, **param_options}
|
327 |
+
return results_dict
|
328 |
+
|
329 |
+
|
330 |
+
|
331 |
+
|
332 |
+
|
unitesting_utils.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import urllib.request
|
3 |
+
|
4 |
+
def load_impact_theory_data():
|
5 |
+
'''
|
6 |
+
Loads impact_theory_data.json data by trying three options:
|
7 |
+
1. Assumes user is in Google Colab environment and loads file from content dir.
|
8 |
+
2. If 1st option doesn't work, assumes user is in course repo and loads from data dir.
|
9 |
+
3. If 2nd option doesn't work, assumes user does not have direct access to data so
|
10 |
+
downloads data direct from course repo.
|
11 |
+
'''
|
12 |
+
try:
|
13 |
+
path = '/content/impact_theory_data.json'
|
14 |
+
with open(path) as f:
|
15 |
+
data = json.load(f)
|
16 |
+
return data
|
17 |
+
except Exception:
|
18 |
+
print(f"Data not available at {path}")
|
19 |
+
try:
|
20 |
+
path = './data/impact_theory_data.json'
|
21 |
+
with open(path) as f:
|
22 |
+
data = json.load(f)
|
23 |
+
print(f'OK, data available at {path}')
|
24 |
+
return data
|
25 |
+
except Exception:
|
26 |
+
print(f'Data not available at {path}, downloading from source')
|
27 |
+
try:
|
28 |
+
with urllib.request.urlopen("https://ra.githubusercontent.com/americanthinker/vectorsearch-applications/main/data/impact_theory_data.json") as url:
|
29 |
+
data = json.load(url)
|
30 |
+
return data
|
31 |
+
except Exception:
|
32 |
+
print('Data cannot be loaded from source, please move data file to one of these paths to run this test:\n\
|
33 |
+
1. "/content/impact_theory_data.json" --> if you are in Google Colab\n\
|
34 |
+
2. "./data/impact_theory_data.json" --> if you are in a local environment\n')
|
utilities/install_kernel.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
echo Installing Jupyter kernel named $1 with display name $2
|
4 |
+
ipython kernel install --name "$1" --user --display-name $2
|
weaviate_interface.py
ADDED
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from weaviate import Client, AuthApiKey
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from openai import OpenAI
|
4 |
+
from sentence_transformers import SentenceTransformer
|
5 |
+
from typing import List, Union, Callable
|
6 |
+
from torch import cuda
|
7 |
+
from tqdm import tqdm
|
8 |
+
import time
|
9 |
+
|
10 |
+
class WeaviateClient(Client):
|
11 |
+
'''
|
12 |
+
A python native Weaviate Client class that encapsulates Weaviate functionalities
|
13 |
+
in one object. Several convenience methods are added for ease of use.
|
14 |
+
|
15 |
+
Args
|
16 |
+
----
|
17 |
+
api_key: str
|
18 |
+
The API key for the Weaviate Cloud Service (WCS) instance.
|
19 |
+
https://console.weaviate.cloud/dashboard
|
20 |
+
|
21 |
+
endpoint: str
|
22 |
+
The url endpoint for the Weaviate Cloud Service instance.
|
23 |
+
|
24 |
+
model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2'
|
25 |
+
The name or path of the SentenceTransformer model to use for vector search.
|
26 |
+
Will also support OpenAI text-embedding-ada-002 model. This param enables
|
27 |
+
the use of most leading models on MTEB Leaderboard:
|
28 |
+
https://huggingface.co/spaces/mteb/leaderboard
|
29 |
+
openai_api_key: str=None
|
30 |
+
The API key for the OpenAI API. Only required if using OpenAI text-embedding-ada-002 model.
|
31 |
+
'''
|
32 |
+
def __init__(self,
|
33 |
+
api_key: str,
|
34 |
+
endpoint: str,
|
35 |
+
model_name_or_path: str='sentence-transformers/all-MiniLM-L6-v2',
|
36 |
+
openai_api_key: str=None,
|
37 |
+
**kwargs
|
38 |
+
):
|
39 |
+
auth_config = AuthApiKey(api_key=api_key)
|
40 |
+
super().__init__(auth_client_secret=auth_config,
|
41 |
+
url=endpoint,
|
42 |
+
**kwargs)
|
43 |
+
self.model_name_or_path = model_name_or_path
|
44 |
+
self.openai_model = False
|
45 |
+
if self.model_name_or_path == 'text-embedding-ada-002':
|
46 |
+
if not openai_api_key:
|
47 |
+
raise ValueError(f'OpenAI API key must be provided to use this model: {self.model_name_or_path}')
|
48 |
+
self.model = OpenAI(api_key=openai_api_key)
|
49 |
+
self.openai_model = True
|
50 |
+
else:
|
51 |
+
self.model = SentenceTransformer(self.model_name_or_path) if self.model_name_or_path else None
|
52 |
+
|
53 |
+
self.display_properties = ['title', 'video_id', 'length', 'thumbnail_url', 'views', 'episode_url', \
|
54 |
+
'doc_id', 'guest', 'content'] # 'playlist_id', 'channel_id', 'author'
|
55 |
+
|
56 |
+
def show_classes(self) -> Union[List[dict], str]:
|
57 |
+
'''
|
58 |
+
Shows all available classes (indexes) on the Weaviate instance.
|
59 |
+
'''
|
60 |
+
classes = self.cluster.get_nodes_status()[0]['shards']
|
61 |
+
if classes:
|
62 |
+
return [d['class'] for d in classes]
|
63 |
+
else:
|
64 |
+
return "No classes found on cluster."
|
65 |
+
|
66 |
+
def show_class_info(self) -> Union[List[dict], str]:
|
67 |
+
'''
|
68 |
+
Shows all information related to the classes (indexes) on the Weaviate instance.
|
69 |
+
'''
|
70 |
+
classes = self.cluster.get_nodes_status()[0]['shards']
|
71 |
+
if classes:
|
72 |
+
return [d for d in classes]
|
73 |
+
else:
|
74 |
+
return "No classes found on cluster."
|
75 |
+
|
76 |
+
def show_class_properties(self, class_name: str) -> Union[dict, str]:
|
77 |
+
'''
|
78 |
+
Shows all properties of a class (index) on the Weaviate instance.
|
79 |
+
'''
|
80 |
+
classes = self.schema.get()
|
81 |
+
if classes:
|
82 |
+
all_classes = classes['classes']
|
83 |
+
for d in all_classes:
|
84 |
+
if d['class'] == class_name:
|
85 |
+
return d['properties']
|
86 |
+
return f'Class "{class_name}" not found on host'
|
87 |
+
return f'No Classes found on host'
|
88 |
+
|
89 |
+
def show_class_config(self, class_name: str) -> Union[dict, str]:
|
90 |
+
'''
|
91 |
+
Shows all configuration of a class (index) on the Weaviate instance.
|
92 |
+
'''
|
93 |
+
classes = self.schema.get()
|
94 |
+
if classes:
|
95 |
+
all_classes = classes['classes']
|
96 |
+
for d in all_classes:
|
97 |
+
if d['class'] == class_name:
|
98 |
+
return d
|
99 |
+
return f'Class "{class_name}" not found on host'
|
100 |
+
return f'No Classes found on host'
|
101 |
+
|
102 |
+
def delete_class(self, class_name: str) -> str:
|
103 |
+
'''
|
104 |
+
Deletes a class (index) on the Weaviate instance, if it exists.
|
105 |
+
'''
|
106 |
+
available = self._check_class_avialability(class_name)
|
107 |
+
if isinstance(available, bool):
|
108 |
+
if available:
|
109 |
+
self.schema.delete_class(class_name)
|
110 |
+
not_deleted = self._check_class_avialability(class_name)
|
111 |
+
if isinstance(not_deleted, bool):
|
112 |
+
if not_deleted:
|
113 |
+
return f'Class "{class_name}" was not deleted. Try again.'
|
114 |
+
else:
|
115 |
+
return f'Class "{class_name}" deleted'
|
116 |
+
return f'Class "{class_name}" deleted and there are no longer any classes on host'
|
117 |
+
return f'Class "{class_name}" not found on host'
|
118 |
+
return available
|
119 |
+
|
120 |
+
def _check_class_avialability(self, class_name: str) -> Union[bool, str]:
|
121 |
+
'''
|
122 |
+
Checks if a class (index) exists on the Weaviate instance.
|
123 |
+
'''
|
124 |
+
classes = self.schema.get()
|
125 |
+
if classes:
|
126 |
+
all_classes = classes['classes']
|
127 |
+
for d in all_classes:
|
128 |
+
if d['class'] == class_name:
|
129 |
+
return True
|
130 |
+
return False
|
131 |
+
else:
|
132 |
+
return f'No Classes found on host'
|
133 |
+
|
134 |
+
def format_response(self,
|
135 |
+
response: dict,
|
136 |
+
class_name: str
|
137 |
+
) -> List[dict]:
|
138 |
+
'''
|
139 |
+
Formats json response from Weaviate into a list of dictionaries.
|
140 |
+
Expands _additional fields if present into top-level dictionary.
|
141 |
+
'''
|
142 |
+
if response.get('errors'):
|
143 |
+
return response['errors'][0]['message']
|
144 |
+
results = []
|
145 |
+
hits = response['data']['Get'][class_name]
|
146 |
+
for d in hits:
|
147 |
+
temp = {k:v for k,v in d.items() if k != '_additional'}
|
148 |
+
if d.get('_additional'):
|
149 |
+
for key in d['_additional']:
|
150 |
+
temp[key] = d['_additional'][key]
|
151 |
+
results.append(temp)
|
152 |
+
return results
|
153 |
+
|
154 |
+
def update_ef_value(self, class_name: str, ef_value: int) -> str:
|
155 |
+
'''
|
156 |
+
Updates ef_value for a class (index) on the Weaviate instance.
|
157 |
+
'''
|
158 |
+
self.schema.update_config(class_name=class_name, config={'vectorIndexConfig': {'ef': ef_value}})
|
159 |
+
print(f'ef_value updated to {ef_value} for class {class_name}')
|
160 |
+
return self.show_class_config(class_name)['vectorIndexConfig']
|
161 |
+
|
162 |
+
def keyword_search(self,
|
163 |
+
request: str,
|
164 |
+
class_name: str,
|
165 |
+
properties: List[str]=['content'],
|
166 |
+
limit: int=10,
|
167 |
+
where_filter: dict=None,
|
168 |
+
display_properties: List[str]=None,
|
169 |
+
return_raw: bool=False) -> Union[dict, List[dict]]:
|
170 |
+
'''
|
171 |
+
Executes Keyword (BM25) search.
|
172 |
+
|
173 |
+
Args
|
174 |
+
----
|
175 |
+
query: str
|
176 |
+
User query.
|
177 |
+
class_name: str
|
178 |
+
Class (index) to search.
|
179 |
+
properties: List[str]
|
180 |
+
List of properties to search across.
|
181 |
+
limit: int=10
|
182 |
+
Number of results to return.
|
183 |
+
display_properties: List[str]=None
|
184 |
+
List of properties to return in response.
|
185 |
+
If None, returns all properties.
|
186 |
+
return_raw: bool=False
|
187 |
+
If True, returns raw response from Weaviate.
|
188 |
+
'''
|
189 |
+
display_properties = display_properties if display_properties else self.display_properties
|
190 |
+
response = (self.query
|
191 |
+
.get(class_name, display_properties)
|
192 |
+
.with_bm25(query=request, properties=properties)
|
193 |
+
.with_additional(['score', "id"])
|
194 |
+
.with_limit(limit)
|
195 |
+
)
|
196 |
+
response = response.with_where(where_filter).do() if where_filter else response.do()
|
197 |
+
if return_raw:
|
198 |
+
return response
|
199 |
+
else:
|
200 |
+
return self.format_response(response, class_name)
|
201 |
+
|
202 |
+
def vector_search(self,
|
203 |
+
request: str,
|
204 |
+
class_name: str,
|
205 |
+
limit: int=10,
|
206 |
+
where_filter: dict=None,
|
207 |
+
display_properties: List[str]=None,
|
208 |
+
return_raw: bool=False,
|
209 |
+
device: str='cuda:0' if cuda.is_available() else 'cpu'
|
210 |
+
) -> Union[dict, List[dict]]:
|
211 |
+
'''
|
212 |
+
Executes vector search using embedding model defined on instantiation
|
213 |
+
of WeaviateClient instance.
|
214 |
+
|
215 |
+
Args
|
216 |
+
----
|
217 |
+
query: str
|
218 |
+
User query.
|
219 |
+
class_name: str
|
220 |
+
Class (index) to search.
|
221 |
+
limit: int=10
|
222 |
+
Number of results to return.
|
223 |
+
display_properties: List[str]=None
|
224 |
+
List of properties to return in response.
|
225 |
+
If None, returns all properties.
|
226 |
+
return_raw: bool=False
|
227 |
+
If True, returns raw response from Weaviate.
|
228 |
+
'''
|
229 |
+
display_properties = display_properties if display_properties else self.display_properties
|
230 |
+
query_vector = self._create_query_vector(request, device=device)
|
231 |
+
response = (
|
232 |
+
self.query
|
233 |
+
.get(class_name, display_properties)
|
234 |
+
.with_near_vector({"vector": query_vector})
|
235 |
+
.with_limit(limit)
|
236 |
+
.with_additional(['distance'])
|
237 |
+
)
|
238 |
+
response = response.with_where(where_filter).do() if where_filter else response.do()
|
239 |
+
if return_raw:
|
240 |
+
return response
|
241 |
+
else:
|
242 |
+
return self.format_response(response, class_name)
|
243 |
+
|
244 |
+
def _create_query_vector(self, query: str, device: str) -> List[float]:
|
245 |
+
'''
|
246 |
+
Creates embedding vector from text query.
|
247 |
+
'''
|
248 |
+
return self.get_openai_embedding(query) if self.openai_model else self.model.encode(query, device=device).tolist()
|
249 |
+
|
250 |
+
def get_openai_embedding(self, query: str) -> List[float]:
|
251 |
+
'''
|
252 |
+
Gets embedding from OpenAI API for query.
|
253 |
+
'''
|
254 |
+
embedding = self.model.embeddings.create(input=query, model='text-embedding-ada-002').model_dump()
|
255 |
+
if embedding:
|
256 |
+
return embedding['data'][0]['embedding']
|
257 |
+
else:
|
258 |
+
raise ValueError(f'No embedding found for query: {query}')
|
259 |
+
|
260 |
+
def hybrid_search(self,
|
261 |
+
request: str,
|
262 |
+
class_name: str,
|
263 |
+
properties: List[str]=['content'],
|
264 |
+
alpha: float=0.5,
|
265 |
+
limit: int=10,
|
266 |
+
where_filter: dict=None,
|
267 |
+
display_properties: List[str]=None,
|
268 |
+
return_raw: bool=False,
|
269 |
+
device: str='cuda:0' if cuda.is_available() else 'cpu'
|
270 |
+
) -> Union[dict, List[dict]]:
|
271 |
+
'''
|
272 |
+
Executes Hybrid (BM25 + Vector) search.
|
273 |
+
|
274 |
+
Args
|
275 |
+
----
|
276 |
+
query: str
|
277 |
+
User query.
|
278 |
+
class_name: str
|
279 |
+
Class (index) to search.
|
280 |
+
properties: List[str]
|
281 |
+
List of properties to search across (using BM25)
|
282 |
+
alpha: float=0.5
|
283 |
+
Weighting factor for BM25 and Vector search.
|
284 |
+
alpha can be any number from 0 to 1, defaulting to 0.5:
|
285 |
+
alpha = 0 executes a pure keyword search method (BM25)
|
286 |
+
alpha = 0.5 weighs the BM25 and vector methods evenly
|
287 |
+
alpha = 1 executes a pure vector search method
|
288 |
+
limit: int=10
|
289 |
+
Number of results to return.
|
290 |
+
display_properties: List[str]=None
|
291 |
+
List of properties to return in response.
|
292 |
+
If None, returns all properties.
|
293 |
+
return_raw: bool=False
|
294 |
+
If True, returns raw response from Weaviate.
|
295 |
+
'''
|
296 |
+
display_properties = display_properties if display_properties else self.display_properties
|
297 |
+
query_vector = self._create_query_vector(request, device=device)
|
298 |
+
response = (
|
299 |
+
self.query
|
300 |
+
.get(class_name, display_properties)
|
301 |
+
.with_hybrid(query=request,
|
302 |
+
alpha=alpha,
|
303 |
+
vector=query_vector,
|
304 |
+
properties=properties,
|
305 |
+
fusion_type='relativeScoreFusion') #hard coded option for now
|
306 |
+
.with_additional(["score", "explainScore"])
|
307 |
+
.with_limit(limit)
|
308 |
+
)
|
309 |
+
|
310 |
+
response = response.with_where(where_filter).do() if where_filter else response.do()
|
311 |
+
if return_raw:
|
312 |
+
return response
|
313 |
+
else:
|
314 |
+
return self.format_response(response, class_name)
|
315 |
+
|
316 |
+
|
317 |
+
class WeaviateIndexer:
|
318 |
+
|
319 |
+
def __init__(self,
|
320 |
+
client: WeaviateClient,
|
321 |
+
batch_size: int=150,
|
322 |
+
num_workers: int=4,
|
323 |
+
dynamic: bool=True,
|
324 |
+
creation_time: int=5,
|
325 |
+
timeout_retries: int=3,
|
326 |
+
connection_error_retries: int=3,
|
327 |
+
callback: Callable=None,
|
328 |
+
):
|
329 |
+
'''
|
330 |
+
Class designed to batch index documents into Weaviate. Instantiating
|
331 |
+
this class will automatically configure the Weaviate batch client.
|
332 |
+
'''
|
333 |
+
self._client = client
|
334 |
+
self._callback = callback if callback else self._default_callback
|
335 |
+
|
336 |
+
self._client.batch.configure(batch_size=batch_size,
|
337 |
+
num_workers=num_workers,
|
338 |
+
dynamic=dynamic,
|
339 |
+
creation_time=creation_time,
|
340 |
+
timeout_retries=timeout_retries,
|
341 |
+
connection_error_retries=connection_error_retries,
|
342 |
+
callback=self._callback
|
343 |
+
)
|
344 |
+
|
345 |
+
def _default_callback(self, results: dict):
|
346 |
+
"""
|
347 |
+
Check batch results for errors.
|
348 |
+
|
349 |
+
Parameters
|
350 |
+
----------
|
351 |
+
results : dict
|
352 |
+
The Weaviate batch creation return value.
|
353 |
+
"""
|
354 |
+
|
355 |
+
if results is not None:
|
356 |
+
for result in results:
|
357 |
+
if "result" in result and "errors" in result["result"]:
|
358 |
+
if "error" in result["result"]["errors"]:
|
359 |
+
print(result["result"])
|
360 |
+
|
361 |
+
def batch_index_data(self,
|
362 |
+
data: List[dict],
|
363 |
+
class_name: str,
|
364 |
+
vector_property: str='content_embedding'
|
365 |
+
) -> None:
|
366 |
+
'''
|
367 |
+
Batch function for fast indexing of data onto Weaviate cluster.
|
368 |
+
This method assumes that self._client.batch is already configured.
|
369 |
+
'''
|
370 |
+
start = time.perf_counter()
|
371 |
+
with self._client.batch as batch:
|
372 |
+
for d in tqdm(data):
|
373 |
+
|
374 |
+
#define single document
|
375 |
+
properties = {k:v for k,v in d.items() if k != vector_property}
|
376 |
+
try:
|
377 |
+
#add data object to batch
|
378 |
+
batch.add_data_object(
|
379 |
+
data_object=properties,
|
380 |
+
class_name=class_name,
|
381 |
+
vector=d[vector_property]
|
382 |
+
)
|
383 |
+
except Exception as e:
|
384 |
+
print(e)
|
385 |
+
continue
|
386 |
+
|
387 |
+
end = time.perf_counter() - start
|
388 |
+
|
389 |
+
print(f'Batch job completed in {round(end/60, 2)} minutes.')
|
390 |
+
class_info = self._client.show_class_info()
|
391 |
+
for i, c in enumerate(class_info):
|
392 |
+
if c['class'] == class_name:
|
393 |
+
print(class_info[i])
|
394 |
+
self._client.batch.shutdown()
|
395 |
+
|
396 |
+
@dataclass
|
397 |
+
class WhereFilter:
|
398 |
+
|
399 |
+
'''
|
400 |
+
Simplified interface for constructing a WhereFilter object.
|
401 |
+
|
402 |
+
Args
|
403 |
+
----
|
404 |
+
path: List[str]
|
405 |
+
List of properties to filter on.
|
406 |
+
operator: str
|
407 |
+
Operator to use for filtering. Options: ['And', 'Or', 'Equal', 'NotEqual',
|
408 |
+
'GreaterThan', 'GreaterThanEqual', 'LessThan', 'LessThanEqual', 'Like',
|
409 |
+
'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
|
410 |
+
value[dataType]: Union[int, bool, str, float, datetime]
|
411 |
+
Value to filter on. The dataType suffix must match the data type of the
|
412 |
+
property being filtered on. At least and only one value type must be provided.
|
413 |
+
'''
|
414 |
+
path: List[str]
|
415 |
+
operator: str
|
416 |
+
valueInt: int=None
|
417 |
+
valueBoolean: bool=None
|
418 |
+
valueText: str=None
|
419 |
+
valueNumber: float=None
|
420 |
+
valueDate = None
|
421 |
+
|
422 |
+
def post_init(self):
|
423 |
+
operators = ['And', 'Or', 'Equal', 'NotEqual','GreaterThan', 'GreaterThanEqual', 'LessThan',\
|
424 |
+
'LessThanEqual', 'Like', 'WithinGeoRange', 'IsNull', 'ContainsAny', 'ContainsAll']
|
425 |
+
if self.operator not in operators:
|
426 |
+
raise ValueError(f'operator must be one of: {operators}, got {self.operator}')
|
427 |
+
values = [self.valueInt, self.valueBoolean, self.valueText, self.valueNumber, self.valueDate]
|
428 |
+
if not any(values):
|
429 |
+
raise ValueError('At least one value must be provided.')
|
430 |
+
if len(values) > 1:
|
431 |
+
raise ValueError('At most one value can be provided.')
|
432 |
+
|
433 |
+
def todict(self):
|
434 |
+
return {k:v for k,v in self.__dict__.items() if v is not None}
|