Duplicate from aadnk/whisper-webui
Browse filesCo-authored-by: Kristian Stangeland <[email protected]>
- .gitattributes +32 -0
- .gitignore +6 -0
- LICENSE.md +195 -0
- README.md +186 -0
- app-local.py +5 -0
- app-network.py +5 -0
- app-shared.py +5 -0
- app.py +627 -0
- cli.py +188 -0
- config.json5 +141 -0
- dockerfile +30 -0
- docs/colab.md +20 -0
- docs/options.md +134 -0
- docs/windows/install_win10_win11.pdf +3 -0
- requirements-fasterWhisper.txt +9 -0
- requirements-whisper.txt +9 -0
- requirements.txt +9 -0
- src/__init__.py +0 -0
- src/config.py +154 -0
- src/conversion/hf_converter.py +67 -0
- src/download.py +78 -0
- src/hooks/progressListener.py +8 -0
- src/hooks/subTaskProgressListener.py +37 -0
- src/hooks/whisperProgressHook.py +91 -0
- src/languages.py +147 -0
- src/modelCache.py +17 -0
- src/prompts/abstractPromptStrategy.py +73 -0
- src/prompts/jsonPromptStrategy.py +49 -0
- src/prompts/prependPromptStrategy.py +31 -0
- src/segments.py +55 -0
- src/source.py +80 -0
- src/utils.py +245 -0
- src/vad.py +568 -0
- src/vadParallel.py +298 -0
- src/whisper/abstractWhisperContainer.py +107 -0
- src/whisper/fasterWhisperContainer.py +207 -0
- src/whisper/whisperContainer.py +216 -0
- src/whisper/whisperFactory.py +19 -0
- tests/segments_test.py +48 -0
- tests/vad_test.py +66 -0
.gitattributes
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
23 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
.vscode/
|
4 |
+
flagged/
|
5 |
+
*.py[cod]
|
6 |
+
*$py.class
|
LICENSE.md
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
==============
|
3 |
+
|
4 |
+
_Version 2.0, January 2004_
|
5 |
+
_<<http://www.apache.org/licenses/>>_
|
6 |
+
|
7 |
+
### Terms and Conditions for use, reproduction, and distribution
|
8 |
+
|
9 |
+
#### 1. Definitions
|
10 |
+
|
11 |
+
“License” shall mean the terms and conditions for use, reproduction, and
|
12 |
+
distribution as defined by Sections 1 through 9 of this document.
|
13 |
+
|
14 |
+
“Licensor” shall mean the copyright owner or entity authorized by the copyright
|
15 |
+
owner that is granting the License.
|
16 |
+
|
17 |
+
“Legal Entity” shall mean the union of the acting entity and all other entities
|
18 |
+
that control, are controlled by, or are under common control with that entity.
|
19 |
+
For the purposes of this definition, “control” means **(i)** the power, direct or
|
20 |
+
indirect, to cause the direction or management of such entity, whether by
|
21 |
+
contract or otherwise, or **(ii)** ownership of fifty percent (50%) or more of the
|
22 |
+
outstanding shares, or **(iii)** beneficial ownership of such entity.
|
23 |
+
|
24 |
+
“You” (or “Your”) shall mean an individual or Legal Entity exercising
|
25 |
+
permissions granted by this License.
|
26 |
+
|
27 |
+
“Source” form shall mean the preferred form for making modifications, including
|
28 |
+
but not limited to software source code, documentation source, and configuration
|
29 |
+
files.
|
30 |
+
|
31 |
+
“Object” form shall mean any form resulting from mechanical transformation or
|
32 |
+
translation of a Source form, including but not limited to compiled object code,
|
33 |
+
generated documentation, and conversions to other media types.
|
34 |
+
|
35 |
+
“Work” shall mean the work of authorship, whether in Source or Object form, made
|
36 |
+
available under the License, as indicated by a copyright notice that is included
|
37 |
+
in or attached to the work (an example is provided in the Appendix below).
|
38 |
+
|
39 |
+
“Derivative Works” shall mean any work, whether in Source or Object form, that
|
40 |
+
is based on (or derived from) the Work and for which the editorial revisions,
|
41 |
+
annotations, elaborations, or other modifications represent, as a whole, an
|
42 |
+
original work of authorship. For the purposes of this License, Derivative Works
|
43 |
+
shall not include works that remain separable from, or merely link (or bind by
|
44 |
+
name) to the interfaces of, the Work and Derivative Works thereof.
|
45 |
+
|
46 |
+
“Contribution” shall mean any work of authorship, including the original version
|
47 |
+
of the Work and any modifications or additions to that Work or Derivative Works
|
48 |
+
thereof, that is intentionally submitted to Licensor for inclusion in the Work
|
49 |
+
by the copyright owner or by an individual or Legal Entity authorized to submit
|
50 |
+
on behalf of the copyright owner. For the purposes of this definition,
|
51 |
+
“submitted” means any form of electronic, verbal, or written communication sent
|
52 |
+
to the Licensor or its representatives, including but not limited to
|
53 |
+
communication on electronic mailing lists, source code control systems, and
|
54 |
+
issue tracking systems that are managed by, or on behalf of, the Licensor for
|
55 |
+
the purpose of discussing and improving the Work, but excluding communication
|
56 |
+
that is conspicuously marked or otherwise designated in writing by the copyright
|
57 |
+
owner as “Not a Contribution.”
|
58 |
+
|
59 |
+
“Contributor” shall mean Licensor and any individual or Legal Entity on behalf
|
60 |
+
of whom a Contribution has been received by Licensor and subsequently
|
61 |
+
incorporated within the Work.
|
62 |
+
|
63 |
+
#### 2. Grant of Copyright License
|
64 |
+
|
65 |
+
Subject to the terms and conditions of this License, each Contributor hereby
|
66 |
+
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
67 |
+
irrevocable copyright license to reproduce, prepare Derivative Works of,
|
68 |
+
publicly display, publicly perform, sublicense, and distribute the Work and such
|
69 |
+
Derivative Works in Source or Object form.
|
70 |
+
|
71 |
+
#### 3. Grant of Patent License
|
72 |
+
|
73 |
+
Subject to the terms and conditions of this License, each Contributor hereby
|
74 |
+
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
75 |
+
irrevocable (except as stated in this section) patent license to make, have
|
76 |
+
made, use, offer to sell, sell, import, and otherwise transfer the Work, where
|
77 |
+
such license applies only to those patent claims licensable by such Contributor
|
78 |
+
that are necessarily infringed by their Contribution(s) alone or by combination
|
79 |
+
of their Contribution(s) with the Work to which such Contribution(s) was
|
80 |
+
submitted. If You institute patent litigation against any entity (including a
|
81 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
|
82 |
+
Contribution incorporated within the Work constitutes direct or contributory
|
83 |
+
patent infringement, then any patent licenses granted to You under this License
|
84 |
+
for that Work shall terminate as of the date such litigation is filed.
|
85 |
+
|
86 |
+
#### 4. Redistribution
|
87 |
+
|
88 |
+
You may reproduce and distribute copies of the Work or Derivative Works thereof
|
89 |
+
in any medium, with or without modifications, and in Source or Object form,
|
90 |
+
provided that You meet the following conditions:
|
91 |
+
|
92 |
+
* **(a)** You must give any other recipients of the Work or Derivative Works a copy of
|
93 |
+
this License; and
|
94 |
+
* **(b)** You must cause any modified files to carry prominent notices stating that You
|
95 |
+
changed the files; and
|
96 |
+
* **(c)** You must retain, in the Source form of any Derivative Works that You distribute,
|
97 |
+
all copyright, patent, trademark, and attribution notices from the Source form
|
98 |
+
of the Work, excluding those notices that do not pertain to any part of the
|
99 |
+
Derivative Works; and
|
100 |
+
* **(d)** If the Work includes a “NOTICE” text file as part of its distribution, then any
|
101 |
+
Derivative Works that You distribute must include a readable copy of the
|
102 |
+
attribution notices contained within such NOTICE file, excluding those notices
|
103 |
+
that do not pertain to any part of the Derivative Works, in at least one of the
|
104 |
+
following places: within a NOTICE text file distributed as part of the
|
105 |
+
Derivative Works; within the Source form or documentation, if provided along
|
106 |
+
with the Derivative Works; or, within a display generated by the Derivative
|
107 |
+
Works, if and wherever such third-party notices normally appear. The contents of
|
108 |
+
the NOTICE file are for informational purposes only and do not modify the
|
109 |
+
License. You may add Your own attribution notices within Derivative Works that
|
110 |
+
You distribute, alongside or as an addendum to the NOTICE text from the Work,
|
111 |
+
provided that such additional attribution notices cannot be construed as
|
112 |
+
modifying the License.
|
113 |
+
|
114 |
+
You may add Your own copyright statement to Your modifications and may provide
|
115 |
+
additional or different license terms and conditions for use, reproduction, or
|
116 |
+
distribution of Your modifications, or for any such Derivative Works as a whole,
|
117 |
+
provided Your use, reproduction, and distribution of the Work otherwise complies
|
118 |
+
with the conditions stated in this License.
|
119 |
+
|
120 |
+
#### 5. Submission of Contributions
|
121 |
+
|
122 |
+
Unless You explicitly state otherwise, any Contribution intentionally submitted
|
123 |
+
for inclusion in the Work by You to the Licensor shall be under the terms and
|
124 |
+
conditions of this License, without any additional terms or conditions.
|
125 |
+
Notwithstanding the above, nothing herein shall supersede or modify the terms of
|
126 |
+
any separate license agreement you may have executed with Licensor regarding
|
127 |
+
such Contributions.
|
128 |
+
|
129 |
+
#### 6. Trademarks
|
130 |
+
|
131 |
+
This License does not grant permission to use the trade names, trademarks,
|
132 |
+
service marks, or product names of the Licensor, except as required for
|
133 |
+
reasonable and customary use in describing the origin of the Work and
|
134 |
+
reproducing the content of the NOTICE file.
|
135 |
+
|
136 |
+
#### 7. Disclaimer of Warranty
|
137 |
+
|
138 |
+
Unless required by applicable law or agreed to in writing, Licensor provides the
|
139 |
+
Work (and each Contributor provides its Contributions) on an “AS IS” BASIS,
|
140 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
|
141 |
+
including, without limitation, any warranties or conditions of TITLE,
|
142 |
+
NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
|
143 |
+
solely responsible for determining the appropriateness of using or
|
144 |
+
redistributing the Work and assume any risks associated with Your exercise of
|
145 |
+
permissions under this License.
|
146 |
+
|
147 |
+
#### 8. Limitation of Liability
|
148 |
+
|
149 |
+
In no event and under no legal theory, whether in tort (including negligence),
|
150 |
+
contract, or otherwise, unless required by applicable law (such as deliberate
|
151 |
+
and grossly negligent acts) or agreed to in writing, shall any Contributor be
|
152 |
+
liable to You for damages, including any direct, indirect, special, incidental,
|
153 |
+
or consequential damages of any character arising as a result of this License or
|
154 |
+
out of the use or inability to use the Work (including but not limited to
|
155 |
+
damages for loss of goodwill, work stoppage, computer failure or malfunction, or
|
156 |
+
any and all other commercial damages or losses), even if such Contributor has
|
157 |
+
been advised of the possibility of such damages.
|
158 |
+
|
159 |
+
#### 9. Accepting Warranty or Additional Liability
|
160 |
+
|
161 |
+
While redistributing the Work or Derivative Works thereof, You may choose to
|
162 |
+
offer, and charge a fee for, acceptance of support, warranty, indemnity, or
|
163 |
+
other liability obligations and/or rights consistent with this License. However,
|
164 |
+
in accepting such obligations, You may act only on Your own behalf and on Your
|
165 |
+
sole responsibility, not on behalf of any other Contributor, and only if You
|
166 |
+
agree to indemnify, defend, and hold each Contributor harmless for any liability
|
167 |
+
incurred by, or claims asserted against, such Contributor by reason of your
|
168 |
+
accepting any such warranty or additional liability.
|
169 |
+
|
170 |
+
_END OF TERMS AND CONDITIONS_
|
171 |
+
|
172 |
+
### APPENDIX: How to apply the Apache License to your work
|
173 |
+
|
174 |
+
To apply the Apache License to your work, attach the following boilerplate
|
175 |
+
notice, with the fields enclosed by brackets `[]` replaced with your own
|
176 |
+
identifying information. (Don't include the brackets!) The text should be
|
177 |
+
enclosed in the appropriate comment syntax for the file format. We also
|
178 |
+
recommend that a file or class name and description of purpose be included on
|
179 |
+
the same “printed page” as the copyright notice for easier identification within
|
180 |
+
third-party archives.
|
181 |
+
|
182 |
+
Copyright [yyyy] [name of copyright owner]
|
183 |
+
|
184 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
185 |
+
you may not use this file except in compliance with the License.
|
186 |
+
You may obtain a copy of the License at
|
187 |
+
|
188 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
189 |
+
|
190 |
+
Unless required by applicable law or agreed to in writing, software
|
191 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
192 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
193 |
+
See the License for the specific language governing permissions and
|
194 |
+
limitations under the License.
|
195 |
+
|
README.md
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Whisper Webui
|
3 |
+
emoji: ⚡
|
4 |
+
colorFrom: pink
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.23.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
duplicated_from: aadnk/whisper-webui
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
15 |
+
|
16 |
+
# Running Locally
|
17 |
+
|
18 |
+
To run this program locally, first install Python 3.9+ and Git. Then install Pytorch 10.1+ and all the other dependencies:
|
19 |
+
```
|
20 |
+
pip install -r requirements.txt
|
21 |
+
```
|
22 |
+
|
23 |
+
You can find detailed instructions for how to install this on Windows 10/11 [here (PDF)](docs/windows/install_win10_win11.pdf).
|
24 |
+
|
25 |
+
Finally, run the full version (no audio length restrictions) of the app with parallel CPU/GPU enabled:
|
26 |
+
```
|
27 |
+
python app.py --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
|
28 |
+
```
|
29 |
+
|
30 |
+
You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
|
31 |
+
```
|
32 |
+
python cli.py \
|
33 |
+
[--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
|
34 |
+
[--vad_merge_window VAD_MERGE_WINDOW] \
|
35 |
+
[--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
|
36 |
+
[--vad_padding VAD_PADDING] \
|
37 |
+
[--vad_prompt_window VAD_PROMPT_WINDOW]
|
38 |
+
[--vad_cpu_cores NUMBER_OF_CORES]
|
39 |
+
[--vad_parallel_devices COMMA_DELIMITED_DEVICES]
|
40 |
+
[--auto_parallel BOOLEAN]
|
41 |
+
```
|
42 |
+
In addition, you may also use URL's in addition to file paths as input.
|
43 |
+
```
|
44 |
+
python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
|
45 |
+
```
|
46 |
+
|
47 |
+
Rather than supplying arguments to `app.py` or `cli.py`, you can also use the configuration file [config.json5](config.json5). See that file for more information.
|
48 |
+
If you want to use a different configuration file, you can use the `WHISPER_WEBUI_CONFIG` environment variable to specify the path to another file.
|
49 |
+
|
50 |
+
### Multiple Files
|
51 |
+
|
52 |
+
You can upload multiple files either through the "Upload files" option, or as a playlist on YouTube.
|
53 |
+
Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section.
|
54 |
+
When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
|
55 |
+
|
56 |
+
## Diarization
|
57 |
+
|
58 |
+
To detect different speakers in the audio, you can use the [whisper-diarization](https://gitlab.com/aadnk/whisper-diarization) application.
|
59 |
+
|
60 |
+
Download the JSON file after running Whisper on an audio file, and then run app.py in the
|
61 |
+
whisper-diarization repository with the audio file and the JSON file as arguments.
|
62 |
+
|
63 |
+
## Whisper Implementation
|
64 |
+
|
65 |
+
You can choose between using `whisper` or `faster-whisper`. [Faster Whisper](https://github.com/guillaumekln/faster-whisper) as a drop-in replacement for the
|
66 |
+
default Whisper which achieves up to a 4x speedup and 2x reduction in memory usage.
|
67 |
+
|
68 |
+
You can install the requirements for a specific Whisper implementation in `requirements-fasterWhisper.txt`
|
69 |
+
or `requirements-whisper.txt`:
|
70 |
+
```
|
71 |
+
pip install -r requirements-fasterWhisper.txt
|
72 |
+
```
|
73 |
+
And then run the App or the CLI with the `--whisper_implementation faster-whisper` flag:
|
74 |
+
```
|
75 |
+
python app.py --whisper_implementation faster-whisper --input_audio_max_duration -1 --server_name 127.0.0.1 --auto_parallel True
|
76 |
+
```
|
77 |
+
You can also select the whisper implementation in `config.json5`:
|
78 |
+
```json5
|
79 |
+
{
|
80 |
+
"whisper_implementation": "faster-whisper"
|
81 |
+
}
|
82 |
+
```
|
83 |
+
### GPU Acceleration
|
84 |
+
|
85 |
+
In order to use GPU acceleration with Faster Whisper, both CUDA 11.2 and cuDNN 8 must be installed. You may want to install it in a virtual environment like Anaconda.
|
86 |
+
|
87 |
+
## Google Colab
|
88 |
+
|
89 |
+
You can also run this Web UI directly on [Google Colab](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing), if you haven't got a GPU powerful enough to run the larger models.
|
90 |
+
|
91 |
+
See the [colab documentation](docs/colab.md) for more information.
|
92 |
+
|
93 |
+
## Parallel Execution
|
94 |
+
|
95 |
+
You can also run both the Web-UI or the CLI on multiple GPUs in parallel, using the `vad_parallel_devices` option. This takes a comma-delimited list of
|
96 |
+
device IDs (0, 1, etc.) that Whisper should be distributed to and run on concurrently:
|
97 |
+
```
|
98 |
+
python cli.py --model large --vad silero-vad --language Japanese \
|
99 |
+
--vad_parallel_devices 0,1 "https://www.youtube.com/watch?v=4cICErqqRSM"
|
100 |
+
```
|
101 |
+
|
102 |
+
Note that this requires a VAD to function properly, otherwise only the first GPU will be used. Though you could use `period-vad` to avoid taking the hit
|
103 |
+
of running Silero-Vad, at a slight cost to accuracy.
|
104 |
+
|
105 |
+
This is achieved by creating N child processes (where N is the number of selected devices), where Whisper is run concurrently. In `app.py`, you can also
|
106 |
+
set the `vad_process_timeout` option. This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory.
|
107 |
+
The default value is 30 minutes.
|
108 |
+
|
109 |
+
```
|
110 |
+
python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600
|
111 |
+
```
|
112 |
+
|
113 |
+
To execute the Silero VAD itself in parallel, use the `vad_cpu_cores` option:
|
114 |
+
```
|
115 |
+
python app.py --input_audio_max_duration -1 --vad_parallel_devices 0,1 --vad_process_timeout 3600 --vad_cpu_cores 4
|
116 |
+
```
|
117 |
+
|
118 |
+
You may also use `vad_process_timeout` with a single device (`--vad_parallel_devices 0`), if you prefer to always free video memory after a period of time.
|
119 |
+
|
120 |
+
### Auto Parallel
|
121 |
+
|
122 |
+
You can also set `auto_parallel` to `True`. This will set `vad_parallel_devices` to use all the GPU devices on the system, and `vad_cpu_cores` to be equal to the number of
|
123 |
+
cores (up to 8):
|
124 |
+
```
|
125 |
+
python app.py --input_audio_max_duration -1 --auto_parallel True
|
126 |
+
```
|
127 |
+
|
128 |
+
# Docker
|
129 |
+
|
130 |
+
To run it in Docker, first install Docker and optionally the NVIDIA Container Toolkit in order to use the GPU.
|
131 |
+
Then either use the GitLab hosted container below, or check out this repository and build an image:
|
132 |
+
```
|
133 |
+
sudo docker build -t whisper-webui:1 .
|
134 |
+
```
|
135 |
+
|
136 |
+
You can then start the WebUI with GPU support like so:
|
137 |
+
```
|
138 |
+
sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
|
139 |
+
```
|
140 |
+
|
141 |
+
Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
|
142 |
+
```
|
143 |
+
sudo docker run -d -p 7860:7860 whisper-webui:1
|
144 |
+
```
|
145 |
+
|
146 |
+
# GitLab Docker Registry
|
147 |
+
|
148 |
+
This Docker container is also hosted on GitLab:
|
149 |
+
|
150 |
+
```
|
151 |
+
sudo docker run -d --gpus=all -p 7860:7860 registry.gitlab.com/aadnk/whisper-webui:latest
|
152 |
+
```
|
153 |
+
|
154 |
+
## Custom Arguments
|
155 |
+
|
156 |
+
You can also pass custom arguments to `app.py` in the Docker container, for instance to be able to use all the GPUs in parallel (replace administrator with your user):
|
157 |
+
```
|
158 |
+
sudo docker run -d --gpus all -p 7860:7860 \
|
159 |
+
--mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
|
160 |
+
--mount type=bind,source=/home/administrator/.cache/huggingface,target=/root/.cache/huggingface \
|
161 |
+
--restart=on-failure:15 registry.gitlab.com/aadnk/whisper-webui:latest \
|
162 |
+
app.py --input_audio_max_duration -1 --server_name 0.0.0.0 --auto_parallel True \
|
163 |
+
--default_vad silero-vad --default_model_name large
|
164 |
+
```
|
165 |
+
|
166 |
+
You can also call `cli.py` the same way:
|
167 |
+
```
|
168 |
+
sudo docker run --gpus all \
|
169 |
+
--mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
|
170 |
+
--mount type=bind,source=/home/administrator/.cache/huggingface,target=/root/.cache/huggingface \
|
171 |
+
--mount type=bind,source=${PWD},target=/app/data \
|
172 |
+
registry.gitlab.com/aadnk/whisper-webui:latest \
|
173 |
+
cli.py --model large --auto_parallel True --vad silero-vad \
|
174 |
+
--output_dir /app/data /app/data/YOUR-FILE-HERE.mp4
|
175 |
+
```
|
176 |
+
|
177 |
+
## Caching
|
178 |
+
|
179 |
+
Note that the models themselves are currently not included in the Docker images, and will be downloaded on the demand.
|
180 |
+
To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
|
181 |
+
prepopulate the directory with the different Whisper models.
|
182 |
+
```
|
183 |
+
sudo docker run -d --gpus=all -p 7860:7860 \
|
184 |
+
--mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper \
|
185 |
+
registry.gitlab.com/aadnk/whisper-webui:latest
|
186 |
+
```
|
app-local.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Run the app with no audio file restrictions
|
2 |
+
from app import create_ui
|
3 |
+
from src.config import ApplicationConfig
|
4 |
+
|
5 |
+
create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1))
|
app-network.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Run the app with no audio file restrictions, and make it available on the network
|
2 |
+
from app import create_ui
|
3 |
+
from src.config import ApplicationConfig
|
4 |
+
|
5 |
+
create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, server_name="0.0.0.0"))
|
app-shared.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Run the app with no audio file restrictions
|
2 |
+
from app import create_ui
|
3 |
+
from src.config import ApplicationConfig
|
4 |
+
|
5 |
+
create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, share=True))
|
app.py
ADDED
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
from typing import Iterator, Union
|
5 |
+
import argparse
|
6 |
+
|
7 |
+
from io import StringIO
|
8 |
+
import os
|
9 |
+
import pathlib
|
10 |
+
import tempfile
|
11 |
+
import zipfile
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
17 |
+
from src.hooks.progressListener import ProgressListener
|
18 |
+
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
19 |
+
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
20 |
+
from src.languages import get_language_names
|
21 |
+
from src.modelCache import ModelCache
|
22 |
+
from src.prompts.jsonPromptStrategy import JsonPromptStrategy
|
23 |
+
from src.prompts.prependPromptStrategy import PrependPromptStrategy
|
24 |
+
from src.source import get_audio_source_collection
|
25 |
+
from src.vadParallel import ParallelContext, ParallelTranscription
|
26 |
+
|
27 |
+
# External programs
|
28 |
+
import ffmpeg
|
29 |
+
|
30 |
+
# UI
|
31 |
+
import gradio as gr
|
32 |
+
|
33 |
+
from src.download import ExceededMaximumDuration, download_url
|
34 |
+
from src.utils import optional_int, slugify, write_srt, write_vtt
|
35 |
+
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
36 |
+
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
37 |
+
from src.whisper.whisperFactory import create_whisper_container
|
38 |
+
|
39 |
+
# Configure more application defaults in config.json5
|
40 |
+
|
41 |
+
# Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
|
42 |
+
MAX_FILE_PREFIX_LENGTH = 17
|
43 |
+
|
44 |
+
# Limit auto_parallel to a certain number of CPUs (specify vad_cpu_cores to get a higher number)
|
45 |
+
MAX_AUTO_CPU_CORES = 8
|
46 |
+
|
47 |
+
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
48 |
+
|
49 |
+
class VadOptions:
|
50 |
+
def __init__(self, vad: str = None, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
|
51 |
+
vadInitialPromptMode: Union[VadInitialPromptMode, str] = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
52 |
+
self.vad = vad
|
53 |
+
self.vadMergeWindow = vadMergeWindow
|
54 |
+
self.vadMaxMergeSize = vadMaxMergeSize
|
55 |
+
self.vadPadding = vadPadding
|
56 |
+
self.vadPromptWindow = vadPromptWindow
|
57 |
+
self.vadInitialPromptMode = vadInitialPromptMode if isinstance(vadInitialPromptMode, VadInitialPromptMode) \
|
58 |
+
else VadInitialPromptMode.from_string(vadInitialPromptMode)
|
59 |
+
|
60 |
+
class WhisperTranscriber:
|
61 |
+
def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
|
62 |
+
vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
|
63 |
+
app_config: ApplicationConfig = None):
|
64 |
+
self.model_cache = ModelCache()
|
65 |
+
self.parallel_device_list = None
|
66 |
+
self.gpu_parallel_context = None
|
67 |
+
self.cpu_parallel_context = None
|
68 |
+
self.vad_process_timeout = vad_process_timeout
|
69 |
+
self.vad_cpu_cores = vad_cpu_cores
|
70 |
+
|
71 |
+
self.vad_model = None
|
72 |
+
self.inputAudioMaxDuration = input_audio_max_duration
|
73 |
+
self.deleteUploadedFiles = delete_uploaded_files
|
74 |
+
self.output_dir = output_dir
|
75 |
+
|
76 |
+
self.app_config = app_config
|
77 |
+
|
78 |
+
def set_parallel_devices(self, vad_parallel_devices: str):
|
79 |
+
self.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
|
80 |
+
|
81 |
+
def set_auto_parallel(self, auto_parallel: bool):
|
82 |
+
if auto_parallel:
|
83 |
+
if torch.cuda.is_available():
|
84 |
+
self.parallel_device_list = [ str(gpu_id) for gpu_id in range(torch.cuda.device_count())]
|
85 |
+
|
86 |
+
self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
|
87 |
+
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
88 |
+
|
89 |
+
# Entry function for the simple tab
|
90 |
+
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
91 |
+
vad, vadMergeWindow, vadMaxMergeSize,
|
92 |
+
word_timestamps: bool = False, highlight_words: bool = False):
|
93 |
+
return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
94 |
+
vad, vadMergeWindow, vadMaxMergeSize,
|
95 |
+
word_timestamps, highlight_words)
|
96 |
+
|
97 |
+
# Entry function for the simple tab progress
|
98 |
+
def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
99 |
+
vad, vadMergeWindow, vadMaxMergeSize,
|
100 |
+
word_timestamps: bool = False, highlight_words: bool = False,
|
101 |
+
progress=gr.Progress()):
|
102 |
+
|
103 |
+
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
|
104 |
+
|
105 |
+
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
106 |
+
word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
|
107 |
+
|
108 |
+
# Entry function for the full tab
|
109 |
+
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
110 |
+
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
111 |
+
# Word timestamps
|
112 |
+
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
113 |
+
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
114 |
+
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
115 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
|
116 |
+
|
117 |
+
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
118 |
+
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
119 |
+
word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
|
120 |
+
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
121 |
+
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
122 |
+
compression_ratio_threshold, logprob_threshold, no_speech_threshold)
|
123 |
+
|
124 |
+
# Entry function for the full tab with progress
|
125 |
+
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
126 |
+
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
127 |
+
# Word timestamps
|
128 |
+
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
129 |
+
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
130 |
+
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
131 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
132 |
+
progress=gr.Progress()):
|
133 |
+
|
134 |
+
# Handle temperature_increment_on_fallback
|
135 |
+
if temperature_increment_on_fallback is not None:
|
136 |
+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
137 |
+
else:
|
138 |
+
temperature = [temperature]
|
139 |
+
|
140 |
+
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
|
141 |
+
|
142 |
+
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
143 |
+
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
144 |
+
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
145 |
+
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
146 |
+
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
147 |
+
progress=progress)
|
148 |
+
|
149 |
+
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
150 |
+
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
151 |
+
**decodeOptions: dict):
|
152 |
+
try:
|
153 |
+
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
154 |
+
|
155 |
+
try:
|
156 |
+
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
157 |
+
selectedModel = modelName if modelName is not None else "base"
|
158 |
+
|
159 |
+
model = create_whisper_container(whisper_implementation=self.app_config.whisper_implementation,
|
160 |
+
model_name=selectedModel, compute_type=self.app_config.compute_type,
|
161 |
+
cache=self.model_cache, models=self.app_config.models)
|
162 |
+
|
163 |
+
# Result
|
164 |
+
download = []
|
165 |
+
zip_file_lookup = {}
|
166 |
+
text = ""
|
167 |
+
vtt = ""
|
168 |
+
|
169 |
+
# Write result
|
170 |
+
downloadDirectory = tempfile.mkdtemp()
|
171 |
+
source_index = 0
|
172 |
+
|
173 |
+
outputDirectory = self.output_dir if self.output_dir is not None else downloadDirectory
|
174 |
+
|
175 |
+
# Progress
|
176 |
+
total_duration = sum([source.get_audio_duration() for source in sources])
|
177 |
+
current_progress = 0
|
178 |
+
|
179 |
+
# A listener that will report progress to Gradio
|
180 |
+
root_progress_listener = self._create_progress_listener(progress)
|
181 |
+
|
182 |
+
# Execute whisper
|
183 |
+
for source in sources:
|
184 |
+
source_prefix = ""
|
185 |
+
source_audio_duration = source.get_audio_duration()
|
186 |
+
|
187 |
+
if (len(sources) > 1):
|
188 |
+
# Prefix (minimum 2 digits)
|
189 |
+
source_index += 1
|
190 |
+
source_prefix = str(source_index).zfill(2) + "_"
|
191 |
+
print("Transcribing ", source.source_path)
|
192 |
+
|
193 |
+
scaled_progress_listener = SubTaskProgressListener(root_progress_listener,
|
194 |
+
base_task_total=total_duration,
|
195 |
+
sub_task_start=current_progress,
|
196 |
+
sub_task_total=source_audio_duration)
|
197 |
+
|
198 |
+
# Transcribe
|
199 |
+
result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
|
200 |
+
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
201 |
+
|
202 |
+
# Update progress
|
203 |
+
current_progress += source_audio_duration
|
204 |
+
|
205 |
+
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
206 |
+
|
207 |
+
if len(sources) > 1:
|
208 |
+
# Add new line separators
|
209 |
+
if (len(source_text) > 0):
|
210 |
+
source_text += os.linesep + os.linesep
|
211 |
+
if (len(source_vtt) > 0):
|
212 |
+
source_vtt += os.linesep + os.linesep
|
213 |
+
|
214 |
+
# Append file name to source text too
|
215 |
+
source_text = source.get_full_name() + ":" + os.linesep + source_text
|
216 |
+
source_vtt = source.get_full_name() + ":" + os.linesep + source_vtt
|
217 |
+
|
218 |
+
# Add to result
|
219 |
+
download.extend(source_download)
|
220 |
+
text += source_text
|
221 |
+
vtt += source_vtt
|
222 |
+
|
223 |
+
if (len(sources) > 1):
|
224 |
+
# Zip files support at least 260 characters, but we'll play it safe and use 200
|
225 |
+
zipFilePrefix = slugify(source_prefix + source.get_short_name(max_length=200), allow_unicode=True)
|
226 |
+
|
227 |
+
# File names in ZIP file can be longer
|
228 |
+
for source_download_file in source_download:
|
229 |
+
# Get file postfix (after last -)
|
230 |
+
filePostfix = os.path.basename(source_download_file).split("-")[-1]
|
231 |
+
zip_file_name = zipFilePrefix + "-" + filePostfix
|
232 |
+
zip_file_lookup[source_download_file] = zip_file_name
|
233 |
+
|
234 |
+
# Create zip file from all sources
|
235 |
+
if len(sources) > 1:
|
236 |
+
downloadAllPath = os.path.join(downloadDirectory, "All_Output-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
|
237 |
+
|
238 |
+
with zipfile.ZipFile(downloadAllPath, 'w', zipfile.ZIP_DEFLATED) as zip:
|
239 |
+
for download_file in download:
|
240 |
+
# Get file name from lookup
|
241 |
+
zip_file_name = zip_file_lookup.get(download_file, os.path.basename(download_file))
|
242 |
+
zip.write(download_file, arcname=zip_file_name)
|
243 |
+
|
244 |
+
download.insert(0, downloadAllPath)
|
245 |
+
|
246 |
+
return download, text, vtt
|
247 |
+
|
248 |
+
finally:
|
249 |
+
# Cleanup source
|
250 |
+
if self.deleteUploadedFiles:
|
251 |
+
for source in sources:
|
252 |
+
print("Deleting source file " + source.source_path)
|
253 |
+
|
254 |
+
try:
|
255 |
+
os.remove(source.source_path)
|
256 |
+
except Exception as e:
|
257 |
+
# Ignore error - it's just a cleanup
|
258 |
+
print("Error deleting source file " + source.source_path + ": " + str(e))
|
259 |
+
|
260 |
+
except ExceededMaximumDuration as e:
|
261 |
+
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
262 |
+
|
263 |
+
def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
|
264 |
+
vadOptions: VadOptions = VadOptions(),
|
265 |
+
progressListener: ProgressListener = None, **decodeOptions: dict):
|
266 |
+
|
267 |
+
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
268 |
+
|
269 |
+
if progressListener is None:
|
270 |
+
# Default progress listener
|
271 |
+
progressListener = ProgressListener()
|
272 |
+
|
273 |
+
if ('task' in decodeOptions):
|
274 |
+
task = decodeOptions.pop('task')
|
275 |
+
|
276 |
+
initial_prompt_mode = vadOptions.vadInitialPromptMode
|
277 |
+
|
278 |
+
# Set default initial prompt mode
|
279 |
+
if (initial_prompt_mode is None):
|
280 |
+
initial_prompt_mode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT
|
281 |
+
|
282 |
+
if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS or
|
283 |
+
initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
284 |
+
# Prepend initial prompt
|
285 |
+
prompt_strategy = PrependPromptStrategy(initial_prompt, initial_prompt_mode)
|
286 |
+
elif (vadOptions.vadInitialPromptMode == VadInitialPromptMode.JSON_PROMPT_MODE):
|
287 |
+
# Use a JSON format to specify the prompt for each segment
|
288 |
+
prompt_strategy = JsonPromptStrategy(initial_prompt)
|
289 |
+
else:
|
290 |
+
raise ValueError("Invalid vadInitialPromptMode: " + initial_prompt_mode)
|
291 |
+
|
292 |
+
# Callable for processing an audio file
|
293 |
+
whisperCallable = model.create_callback(language, task, prompt_strategy=prompt_strategy, **decodeOptions)
|
294 |
+
|
295 |
+
# The results
|
296 |
+
if (vadOptions.vad == 'silero-vad'):
|
297 |
+
# Silero VAD where non-speech gaps are transcribed
|
298 |
+
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadOptions)
|
299 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
|
300 |
+
elif (vadOptions.vad == 'silero-vad-skip-gaps'):
|
301 |
+
# Silero VAD where non-speech gaps are simply ignored
|
302 |
+
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadOptions)
|
303 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
|
304 |
+
elif (vadOptions.vad == 'silero-vad-expand-into-gaps'):
|
305 |
+
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
306 |
+
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadOptions)
|
307 |
+
result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
|
308 |
+
elif (vadOptions.vad == 'periodic-vad'):
|
309 |
+
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
310 |
+
# it may create a break in the middle of a sentence, causing some artifacts.
|
311 |
+
periodic_vad = VadPeriodicTranscription()
|
312 |
+
period_config = PeriodicTranscriptionConfig(periodic_duration=vadOptions.vadMaxMergeSize, max_prompt_window=vadOptions.vadPromptWindow)
|
313 |
+
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
|
314 |
+
|
315 |
+
else:
|
316 |
+
if (self._has_parallel_devices()):
|
317 |
+
# Use a simple period transcription instead, as we need to use the parallel context
|
318 |
+
periodic_vad = VadPeriodicTranscription()
|
319 |
+
period_config = PeriodicTranscriptionConfig(periodic_duration=math.inf, max_prompt_window=1)
|
320 |
+
|
321 |
+
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
|
322 |
+
else:
|
323 |
+
# Default VAD
|
324 |
+
result = whisperCallable.invoke(audio_path, 0, None, None, progress_listener=progressListener)
|
325 |
+
|
326 |
+
return result
|
327 |
+
|
328 |
+
def _create_progress_listener(self, progress: gr.Progress):
|
329 |
+
if (progress is None):
|
330 |
+
# Dummy progress listener
|
331 |
+
return ProgressListener()
|
332 |
+
|
333 |
+
class ForwardingProgressListener(ProgressListener):
|
334 |
+
def __init__(self, progress: gr.Progress):
|
335 |
+
self.progress = progress
|
336 |
+
|
337 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
338 |
+
# From 0 to 1
|
339 |
+
self.progress(current / total)
|
340 |
+
|
341 |
+
def on_finished(self):
|
342 |
+
self.progress(1)
|
343 |
+
|
344 |
+
return ForwardingProgressListener(progress)
|
345 |
+
|
346 |
+
def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig,
|
347 |
+
progressListener: ProgressListener = None):
|
348 |
+
if (not self._has_parallel_devices()):
|
349 |
+
# No parallel devices, so just run the VAD and Whisper in sequence
|
350 |
+
return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
|
351 |
+
|
352 |
+
gpu_devices = self.parallel_device_list
|
353 |
+
|
354 |
+
if (gpu_devices is None or len(gpu_devices) == 0):
|
355 |
+
# No GPU devices specified, pass the current environment variable to the first GPU process. This may be NULL.
|
356 |
+
gpu_devices = [os.environ.get("CUDA_VISIBLE_DEVICES", None)]
|
357 |
+
|
358 |
+
# Create parallel context if needed
|
359 |
+
if (self.gpu_parallel_context is None):
|
360 |
+
# Create a context wih processes and automatically clear the pool after 1 hour of inactivity
|
361 |
+
self.gpu_parallel_context = ParallelContext(num_processes=len(gpu_devices), auto_cleanup_timeout_seconds=self.vad_process_timeout)
|
362 |
+
# We also need a CPU context for the VAD
|
363 |
+
if (self.cpu_parallel_context is None):
|
364 |
+
self.cpu_parallel_context = ParallelContext(num_processes=self.vad_cpu_cores, auto_cleanup_timeout_seconds=self.vad_process_timeout)
|
365 |
+
|
366 |
+
parallel_vad = ParallelTranscription()
|
367 |
+
return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
|
368 |
+
config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
|
369 |
+
cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
|
370 |
+
progress_listener=progressListener)
|
371 |
+
|
372 |
+
def _has_parallel_devices(self):
|
373 |
+
return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
|
374 |
+
|
375 |
+
def _concat_prompt(self, prompt1, prompt2):
|
376 |
+
if (prompt1 is None):
|
377 |
+
return prompt2
|
378 |
+
elif (prompt2 is None):
|
379 |
+
return prompt1
|
380 |
+
else:
|
381 |
+
return prompt1 + " " + prompt2
|
382 |
+
|
383 |
+
def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
|
384 |
+
# Use Silero VAD
|
385 |
+
if (self.vad_model is None):
|
386 |
+
self.vad_model = VadSileroTranscription()
|
387 |
+
|
388 |
+
config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
|
389 |
+
max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
|
390 |
+
segment_padding_left=vadOptions.vadPadding, segment_padding_right=vadOptions.vadPadding,
|
391 |
+
max_prompt_window=vadOptions.vadPromptWindow)
|
392 |
+
|
393 |
+
return config
|
394 |
+
|
395 |
+
def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
|
396 |
+
if not os.path.exists(output_dir):
|
397 |
+
os.makedirs(output_dir)
|
398 |
+
|
399 |
+
text = result["text"]
|
400 |
+
language = result["language"]
|
401 |
+
languageMaxLineWidth = self.__get_max_line_width(language)
|
402 |
+
|
403 |
+
print("Max line width " + str(languageMaxLineWidth))
|
404 |
+
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
405 |
+
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
406 |
+
json_result = json.dumps(result, indent=4, ensure_ascii=False)
|
407 |
+
|
408 |
+
output_files = []
|
409 |
+
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
410 |
+
output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
|
411 |
+
output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
|
412 |
+
output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
|
413 |
+
|
414 |
+
return output_files, text, vtt
|
415 |
+
|
416 |
+
def clear_cache(self):
|
417 |
+
self.model_cache.clear()
|
418 |
+
self.vad_model = None
|
419 |
+
|
420 |
+
def __get_source(self, urlData, multipleFiles, microphoneData):
|
421 |
+
return get_audio_source_collection(urlData, multipleFiles, microphoneData, self.inputAudioMaxDuration)
|
422 |
+
|
423 |
+
def __get_max_line_width(self, language: str) -> int:
|
424 |
+
if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
|
425 |
+
# Chinese characters and kana are wider, so limit line length to 40 characters
|
426 |
+
return 40
|
427 |
+
else:
|
428 |
+
# TODO: Add more languages
|
429 |
+
# 80 latin characters should fit on a 1080p/720p screen
|
430 |
+
return 80
|
431 |
+
|
432 |
+
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
|
433 |
+
segmentStream = StringIO()
|
434 |
+
|
435 |
+
if format == 'vtt':
|
436 |
+
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
437 |
+
elif format == 'srt':
|
438 |
+
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
439 |
+
else:
|
440 |
+
raise Exception("Unknown format " + format)
|
441 |
+
|
442 |
+
segmentStream.seek(0)
|
443 |
+
return segmentStream.read()
|
444 |
+
|
445 |
+
def __create_file(self, text: str, directory: str, fileName: str) -> str:
|
446 |
+
# Write the text to a file
|
447 |
+
with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
|
448 |
+
file.write(text)
|
449 |
+
|
450 |
+
return file.name
|
451 |
+
|
452 |
+
def close(self):
|
453 |
+
print("Closing parallel contexts")
|
454 |
+
self.clear_cache()
|
455 |
+
|
456 |
+
if (self.gpu_parallel_context is not None):
|
457 |
+
self.gpu_parallel_context.close()
|
458 |
+
if (self.cpu_parallel_context is not None):
|
459 |
+
self.cpu_parallel_context.close()
|
460 |
+
|
461 |
+
|
462 |
+
def create_ui(app_config: ApplicationConfig):
|
463 |
+
ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
|
464 |
+
app_config.delete_uploaded_files, app_config.output_dir, app_config)
|
465 |
+
|
466 |
+
# Specify a list of devices to use for parallel processing
|
467 |
+
ui.set_parallel_devices(app_config.vad_parallel_devices)
|
468 |
+
ui.set_auto_parallel(app_config.auto_parallel)
|
469 |
+
|
470 |
+
is_whisper = False
|
471 |
+
|
472 |
+
if app_config.whisper_implementation == "whisper":
|
473 |
+
implementation_name = "Whisper"
|
474 |
+
is_whisper = True
|
475 |
+
elif app_config.whisper_implementation in ["faster-whisper", "faster_whisper"]:
|
476 |
+
implementation_name = "Faster Whisper"
|
477 |
+
else:
|
478 |
+
# Try to convert from camel-case to title-case
|
479 |
+
implementation_name = app_config.whisper_implementation.title().replace("_", " ").replace("-", " ")
|
480 |
+
|
481 |
+
ui_description = implementation_name + " is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
|
482 |
+
ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
|
483 |
+
ui_description += " as well as speech translation and language identification. "
|
484 |
+
|
485 |
+
ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
|
486 |
+
|
487 |
+
# Recommend faster-whisper
|
488 |
+
if is_whisper:
|
489 |
+
ui_description += "\n\n\n\nFor faster inference on GPU, try [faster-whisper](https://huggingface.co/spaces/aadnk/faster-whisper-webui)."
|
490 |
+
|
491 |
+
if app_config.input_audio_max_duration > 0:
|
492 |
+
ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
|
493 |
+
|
494 |
+
ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
|
495 |
+
|
496 |
+
whisper_models = app_config.get_model_names()
|
497 |
+
|
498 |
+
common_inputs = lambda : [
|
499 |
+
gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
|
500 |
+
gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
|
501 |
+
gr.Text(label="URL (YouTube, etc.)"),
|
502 |
+
gr.File(label="Upload Files", file_count="multiple"),
|
503 |
+
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
504 |
+
gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
|
505 |
+
]
|
506 |
+
|
507 |
+
common_vad_inputs = lambda : [
|
508 |
+
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
|
509 |
+
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
|
510 |
+
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
|
511 |
+
]
|
512 |
+
|
513 |
+
common_word_timestamps_inputs = lambda : [
|
514 |
+
gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
|
515 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
|
516 |
+
]
|
517 |
+
|
518 |
+
is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
|
519 |
+
|
520 |
+
simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
|
521 |
+
description=ui_description, article=ui_article, inputs=[
|
522 |
+
*common_inputs(),
|
523 |
+
*common_vad_inputs(),
|
524 |
+
*common_word_timestamps_inputs(),
|
525 |
+
], outputs=[
|
526 |
+
gr.File(label="Download"),
|
527 |
+
gr.Text(label="Transcription"),
|
528 |
+
gr.Text(label="Segments")
|
529 |
+
])
|
530 |
+
|
531 |
+
full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
|
532 |
+
|
533 |
+
full_transcribe = gr.Interface(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
|
534 |
+
description=full_description, article=ui_article, inputs=[
|
535 |
+
*common_inputs(),
|
536 |
+
|
537 |
+
*common_vad_inputs(),
|
538 |
+
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
|
539 |
+
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
|
540 |
+
gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode"),
|
541 |
+
|
542 |
+
*common_word_timestamps_inputs(),
|
543 |
+
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
544 |
+
gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
|
545 |
+
|
546 |
+
gr.TextArea(label="Initial Prompt"),
|
547 |
+
gr.Number(label="Temperature", value=app_config.temperature),
|
548 |
+
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
|
549 |
+
gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
|
550 |
+
gr.Number(label="Patience - Zero temperature", value=app_config.patience),
|
551 |
+
gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
|
552 |
+
gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
|
553 |
+
gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
|
554 |
+
gr.Checkbox(label="FP16", value=app_config.fp16),
|
555 |
+
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
556 |
+
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
557 |
+
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
558 |
+
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
|
559 |
+
], outputs=[
|
560 |
+
gr.File(label="Download"),
|
561 |
+
gr.Text(label="Transcription"),
|
562 |
+
gr.Text(label="Segments")
|
563 |
+
])
|
564 |
+
|
565 |
+
demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
|
566 |
+
|
567 |
+
# Queue up the demo
|
568 |
+
if is_queue_mode:
|
569 |
+
demo.queue(concurrency_count=app_config.queue_concurrency_count)
|
570 |
+
print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
|
571 |
+
else:
|
572 |
+
print("Queue mode disabled - progress bars will not be shown.")
|
573 |
+
|
574 |
+
demo.launch(share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
|
575 |
+
|
576 |
+
# Clean up
|
577 |
+
ui.close()
|
578 |
+
|
579 |
+
if __name__ == '__main__':
|
580 |
+
default_app_config = ApplicationConfig.create_default()
|
581 |
+
whisper_models = default_app_config.get_model_names()
|
582 |
+
|
583 |
+
# Environment variable overrides
|
584 |
+
default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
|
585 |
+
|
586 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
587 |
+
parser.add_argument("--input_audio_max_duration", type=int, default=default_app_config.input_audio_max_duration, \
|
588 |
+
help="Maximum audio file length in seconds, or -1 for no limit.") # 600
|
589 |
+
parser.add_argument("--share", type=bool, default=default_app_config.share, \
|
590 |
+
help="True to share the app on HuggingFace.") # False
|
591 |
+
parser.add_argument("--server_name", type=str, default=default_app_config.server_name, \
|
592 |
+
help="The host or IP to bind to. If None, bind to localhost.") # None
|
593 |
+
parser.add_argument("--server_port", type=int, default=default_app_config.server_port, \
|
594 |
+
help="The port to bind to.") # 7860
|
595 |
+
parser.add_argument("--queue_concurrency_count", type=int, default=default_app_config.queue_concurrency_count, \
|
596 |
+
help="The number of concurrent requests to process.") # 1
|
597 |
+
parser.add_argument("--default_model_name", type=str, choices=whisper_models, default=default_app_config.default_model_name, \
|
598 |
+
help="The default model name.") # medium
|
599 |
+
parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
|
600 |
+
help="The default VAD.") # silero-vad
|
601 |
+
parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
|
602 |
+
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
603 |
+
parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
|
604 |
+
help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
|
605 |
+
parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
|
606 |
+
help="The number of CPU cores to use for VAD pre-processing.") # 1
|
607 |
+
parser.add_argument("--vad_process_timeout", type=float, default=default_app_config.vad_process_timeout, \
|
608 |
+
help="The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.") # 1800
|
609 |
+
parser.add_argument("--auto_parallel", type=bool, default=default_app_config.auto_parallel, \
|
610 |
+
help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
|
611 |
+
parser.add_argument("--output_dir", "-o", type=str, default=default_app_config.output_dir, \
|
612 |
+
help="directory to save the outputs")
|
613 |
+
parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
|
614 |
+
help="the Whisper implementation to use")
|
615 |
+
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
616 |
+
help="the compute type to use for inference")
|
617 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
618 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
619 |
+
|
620 |
+
args = parser.parse_args().__dict__
|
621 |
+
|
622 |
+
updated_config = default_app_config.update(**args)
|
623 |
+
|
624 |
+
if (threads := args.pop("threads")) > 0:
|
625 |
+
torch.set_num_threads(threads)
|
626 |
+
|
627 |
+
create_ui(app_config=updated_config)
|
cli.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
from urllib.parse import urlparse
|
5 |
+
import warnings
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from app import VadOptions, WhisperTranscriber
|
10 |
+
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
|
11 |
+
from src.download import download_url
|
12 |
+
from src.languages import get_language_names
|
13 |
+
|
14 |
+
from src.utils import optional_float, optional_int, str2bool
|
15 |
+
from src.whisper.whisperFactory import create_whisper_container
|
16 |
+
|
17 |
+
def cli():
|
18 |
+
app_config = ApplicationConfig.create_default()
|
19 |
+
whisper_models = app_config.get_model_names()
|
20 |
+
|
21 |
+
# For the CLI, we fallback to saving the output to the current directory
|
22 |
+
output_dir = app_config.output_dir if app_config.output_dir is not None else "."
|
23 |
+
|
24 |
+
# Environment variable overrides
|
25 |
+
default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", app_config.whisper_implementation)
|
26 |
+
|
27 |
+
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
28 |
+
parser.add_argument("audio", nargs="+", type=str, \
|
29 |
+
help="audio file(s) to transcribe")
|
30 |
+
parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
|
31 |
+
help="name of the Whisper model to use") # medium
|
32 |
+
parser.add_argument("--model_dir", type=str, default=app_config.model_dir, \
|
33 |
+
help="the path to save model files; uses ~/.cache/whisper by default")
|
34 |
+
parser.add_argument("--device", default=app_config.device, \
|
35 |
+
help="device to use for PyTorch inference")
|
36 |
+
parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
|
37 |
+
help="directory to save the outputs")
|
38 |
+
parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
|
39 |
+
help="whether to print out the progress and debug messages")
|
40 |
+
parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],\
|
41 |
+
help="the Whisper implementation to use")
|
42 |
+
|
43 |
+
parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
|
44 |
+
help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
45 |
+
parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_language_names()), \
|
46 |
+
help="language spoken in the audio, specify None to perform language detection")
|
47 |
+
|
48 |
+
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
49 |
+
help="The voice activity detection algorithm to use") # silero-vad
|
50 |
+
parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES, \
|
51 |
+
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
52 |
+
parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
|
53 |
+
help="The window size (in seconds) to merge voice segments")
|
54 |
+
parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
|
55 |
+
help="The maximum size (in seconds) of a voice segment")
|
56 |
+
parser.add_argument("--vad_padding", type=optional_float, default=app_config.vad_padding, \
|
57 |
+
help="The padding (in seconds) to add to each voice segment")
|
58 |
+
parser.add_argument("--vad_prompt_window", type=optional_float, default=app_config.vad_prompt_window, \
|
59 |
+
help="The window size of the prompt to pass to Whisper")
|
60 |
+
parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
|
61 |
+
help="The number of CPU cores to use for VAD pre-processing.") # 1
|
62 |
+
parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices, \
|
63 |
+
help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
|
64 |
+
parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
|
65 |
+
help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
|
66 |
+
|
67 |
+
parser.add_argument("--temperature", type=float, default=app_config.temperature, \
|
68 |
+
help="temperature to use for sampling")
|
69 |
+
parser.add_argument("--best_of", type=optional_int, default=app_config.best_of, \
|
70 |
+
help="number of candidates when sampling with non-zero temperature")
|
71 |
+
parser.add_argument("--beam_size", type=optional_int, default=app_config.beam_size, \
|
72 |
+
help="number of beams in beam search, only applicable when temperature is zero")
|
73 |
+
parser.add_argument("--patience", type=float, default=app_config.patience, \
|
74 |
+
help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
75 |
+
parser.add_argument("--length_penalty", type=float, default=app_config.length_penalty, \
|
76 |
+
help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
|
77 |
+
|
78 |
+
parser.add_argument("--suppress_tokens", type=str, default=app_config.suppress_tokens, \
|
79 |
+
help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
80 |
+
parser.add_argument("--initial_prompt", type=str, default=app_config.initial_prompt, \
|
81 |
+
help="optional text to provide as a prompt for the first window.")
|
82 |
+
parser.add_argument("--condition_on_previous_text", type=str2bool, default=app_config.condition_on_previous_text, \
|
83 |
+
help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
84 |
+
parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
|
85 |
+
help="whether to perform inference in fp16; True by default")
|
86 |
+
parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
87 |
+
help="the compute type to use for inference")
|
88 |
+
|
89 |
+
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
|
90 |
+
help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
91 |
+
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=app_config.compression_ratio_threshold, \
|
92 |
+
help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
93 |
+
parser.add_argument("--logprob_threshold", type=optional_float, default=app_config.logprob_threshold, \
|
94 |
+
help="if the average log probability is lower than this value, treat the decoding as failed")
|
95 |
+
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
96 |
+
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
97 |
+
|
98 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
|
99 |
+
help="(experimental) extract word-level timestamps and refine the results based on them")
|
100 |
+
parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
|
101 |
+
help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
102 |
+
parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
|
103 |
+
help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
104 |
+
parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
|
105 |
+
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
106 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
107 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
108 |
+
|
109 |
+
args = parser.parse_args().__dict__
|
110 |
+
model_name: str = args.pop("model")
|
111 |
+
model_dir: str = args.pop("model_dir")
|
112 |
+
output_dir: str = args.pop("output_dir")
|
113 |
+
device: str = args.pop("device")
|
114 |
+
os.makedirs(output_dir, exist_ok=True)
|
115 |
+
|
116 |
+
if (threads := args.pop("threads")) > 0:
|
117 |
+
torch.set_num_threads(threads)
|
118 |
+
|
119 |
+
whisper_implementation = args.pop("whisper_implementation")
|
120 |
+
print(f"Using {whisper_implementation} for Whisper")
|
121 |
+
|
122 |
+
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
123 |
+
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
124 |
+
args["language"] = "en"
|
125 |
+
|
126 |
+
temperature = args.pop("temperature")
|
127 |
+
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
128 |
+
if temperature_increment_on_fallback is not None:
|
129 |
+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
130 |
+
else:
|
131 |
+
temperature = [temperature]
|
132 |
+
|
133 |
+
vad = args.pop("vad")
|
134 |
+
vad_initial_prompt_mode = args.pop("vad_initial_prompt_mode")
|
135 |
+
vad_merge_window = args.pop("vad_merge_window")
|
136 |
+
vad_max_merge_size = args.pop("vad_max_merge_size")
|
137 |
+
vad_padding = args.pop("vad_padding")
|
138 |
+
vad_prompt_window = args.pop("vad_prompt_window")
|
139 |
+
vad_cpu_cores = args.pop("vad_cpu_cores")
|
140 |
+
auto_parallel = args.pop("auto_parallel")
|
141 |
+
|
142 |
+
compute_type = args.pop("compute_type")
|
143 |
+
highlight_words = args.pop("highlight_words")
|
144 |
+
|
145 |
+
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
146 |
+
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
147 |
+
transcriber.set_auto_parallel(auto_parallel)
|
148 |
+
|
149 |
+
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
150 |
+
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
151 |
+
|
152 |
+
if (transcriber._has_parallel_devices()):
|
153 |
+
print("Using parallel devices:", transcriber.parallel_device_list)
|
154 |
+
|
155 |
+
for audio_path in args.pop("audio"):
|
156 |
+
sources = []
|
157 |
+
|
158 |
+
# Detect URL and download the audio
|
159 |
+
if (uri_validator(audio_path)):
|
160 |
+
# Download from YouTube/URL directly
|
161 |
+
for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
|
162 |
+
source_name = os.path.basename(source_path)
|
163 |
+
sources.append({ "path": source_path, "name": source_name })
|
164 |
+
else:
|
165 |
+
sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
|
166 |
+
|
167 |
+
for source in sources:
|
168 |
+
source_path = source["path"]
|
169 |
+
source_name = source["name"]
|
170 |
+
|
171 |
+
vadOptions = VadOptions(vad, vad_merge_window, vad_max_merge_size, vad_padding, vad_prompt_window,
|
172 |
+
VadInitialPromptMode.from_string(vad_initial_prompt_mode))
|
173 |
+
|
174 |
+
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
175 |
+
|
176 |
+
transcriber.write_result(result, source_name, output_dir, highlight_words)
|
177 |
+
|
178 |
+
transcriber.close()
|
179 |
+
|
180 |
+
def uri_validator(x):
|
181 |
+
try:
|
182 |
+
result = urlparse(x)
|
183 |
+
return all([result.scheme, result.netloc])
|
184 |
+
except:
|
185 |
+
return False
|
186 |
+
|
187 |
+
if __name__ == '__main__':
|
188 |
+
cli()
|
config.json5
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"models": [
|
3 |
+
// Configuration for the built-in models. You can remove any of these
|
4 |
+
// if you don't want to use the default models.
|
5 |
+
{
|
6 |
+
"name": "tiny",
|
7 |
+
"url": "tiny"
|
8 |
+
},
|
9 |
+
{
|
10 |
+
"name": "base",
|
11 |
+
"url": "base"
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"name": "small",
|
15 |
+
"url": "small"
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"name": "medium",
|
19 |
+
"url": "medium"
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"name": "large",
|
23 |
+
"url": "large"
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"name": "large-v2",
|
27 |
+
"url": "large-v2"
|
28 |
+
},
|
29 |
+
// Uncomment to add custom Japanese models
|
30 |
+
//{
|
31 |
+
// "name": "whisper-large-v2-mix-jp",
|
32 |
+
// "url": "vumichien/whisper-large-v2-mix-jp",
|
33 |
+
// // The type of the model. Can be "huggingface" or "whisper" - "whisper" is the default.
|
34 |
+
// // HuggingFace models are loaded using the HuggingFace transformers library and then converted to Whisper models.
|
35 |
+
// "type": "huggingface",
|
36 |
+
//},
|
37 |
+
//{
|
38 |
+
// "name": "local-model",
|
39 |
+
// "url": "path/to/local/model",
|
40 |
+
//},
|
41 |
+
//{
|
42 |
+
// "name": "remote-model",
|
43 |
+
// "url": "https://example.com/path/to/model",
|
44 |
+
//}
|
45 |
+
],
|
46 |
+
// Configuration options that will be used if they are not specified in the command line arguments.
|
47 |
+
|
48 |
+
// * WEBUI options *
|
49 |
+
|
50 |
+
// Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
|
51 |
+
"input_audio_max_duration": 600,
|
52 |
+
// True to share the app on HuggingFace.
|
53 |
+
"share": false,
|
54 |
+
// The host or IP to bind to. If None, bind to localhost.
|
55 |
+
"server_name": null,
|
56 |
+
// The port to bind to.
|
57 |
+
"server_port": 7860,
|
58 |
+
// The number of workers to use for the web server. Use -1 to disable queueing.
|
59 |
+
"queue_concurrency_count": 1,
|
60 |
+
// Whether or not to automatically delete all uploaded files, to save disk space
|
61 |
+
"delete_uploaded_files": true,
|
62 |
+
|
63 |
+
// * General options *
|
64 |
+
|
65 |
+
// The default implementation to use for Whisper. Can be "whisper" or "faster-whisper".
|
66 |
+
// Note that you must either install the requirements for faster-whisper (requirements-fasterWhisper.txt)
|
67 |
+
// or whisper (requirements.txt)
|
68 |
+
"whisper_implementation": "whisper",
|
69 |
+
|
70 |
+
// The default model name.
|
71 |
+
"default_model_name": "medium",
|
72 |
+
// The default VAD.
|
73 |
+
"default_vad": "silero-vad",
|
74 |
+
// A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.
|
75 |
+
"vad_parallel_devices": "",
|
76 |
+
// The number of CPU cores to use for VAD pre-processing.
|
77 |
+
"vad_cpu_cores": 1,
|
78 |
+
// The number of seconds before inactivate processes are terminated. Use 0 to close processes immediately, or None for no timeout.
|
79 |
+
"vad_process_timeout": 1800,
|
80 |
+
// True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
|
81 |
+
"auto_parallel": false,
|
82 |
+
// Directory to save the outputs (CLI will use the current directory if not specified)
|
83 |
+
"output_dir": null,
|
84 |
+
// The path to save model files; uses ~/.cache/whisper by default
|
85 |
+
"model_dir": null,
|
86 |
+
// Device to use for PyTorch inference, or Null to use the default device
|
87 |
+
"device": null,
|
88 |
+
// Whether to print out the progress and debug messages
|
89 |
+
"verbose": true,
|
90 |
+
// Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
|
91 |
+
"task": "transcribe",
|
92 |
+
// Language spoken in the audio, specify None to perform language detection
|
93 |
+
"language": null,
|
94 |
+
// The window size (in seconds) to merge voice segments
|
95 |
+
"vad_merge_window": 5,
|
96 |
+
// The maximum size (in seconds) of a voice segment
|
97 |
+
"vad_max_merge_size": 30,
|
98 |
+
// The padding (in seconds) to add to each voice segment
|
99 |
+
"vad_padding": 1,
|
100 |
+
// Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
|
101 |
+
"vad_initial_prompt_mode": "prepend_first_segment",
|
102 |
+
// The window size of the prompt to pass to Whisper
|
103 |
+
"vad_prompt_window": 3,
|
104 |
+
// Temperature to use for sampling
|
105 |
+
"temperature": 0,
|
106 |
+
// Number of candidates when sampling with non-zero temperature
|
107 |
+
"best_of": 5,
|
108 |
+
// Number of beams in beam search, only applicable when temperature is zero
|
109 |
+
"beam_size": 5,
|
110 |
+
// Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
|
111 |
+
"patience": 1,
|
112 |
+
// Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
|
113 |
+
"length_penalty": null,
|
114 |
+
// Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
|
115 |
+
"suppress_tokens": "-1",
|
116 |
+
// Optional text to provide as a prompt for the first window
|
117 |
+
"initial_prompt": null,
|
118 |
+
// If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
|
119 |
+
"condition_on_previous_text": true,
|
120 |
+
// Whether to perform inference in fp16; True by default
|
121 |
+
"fp16": true,
|
122 |
+
// The compute type used by faster-whisper. Can be "int8". "int16" or "float16".
|
123 |
+
"compute_type": "auto",
|
124 |
+
// Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
|
125 |
+
"temperature_increment_on_fallback": 0.2,
|
126 |
+
// If the gzip compression ratio is higher than this value, treat the decoding as failed
|
127 |
+
"compression_ratio_threshold": 2.4,
|
128 |
+
// If the average log probability is lower than this value, treat the decoding as failed
|
129 |
+
"logprob_threshold": -1.0,
|
130 |
+
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
131 |
+
"no_speech_threshold": 0.6,
|
132 |
+
|
133 |
+
// (experimental) extract word-level timestamps and refine the results based on them
|
134 |
+
"word_timestamps": false,
|
135 |
+
// if word_timestamps is True, merge these punctuation symbols with the next word
|
136 |
+
"prepend_punctuations": "\"\'“¿([{-",
|
137 |
+
// if word_timestamps is True, merge these punctuation symbols with the previous word
|
138 |
+
"append_punctuations": "\"\'.。,,!!??::”)]}、",
|
139 |
+
// (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
|
140 |
+
"highlight_words": false,
|
141 |
+
}
|
dockerfile
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# docker build -t whisper-webui --build-arg WHISPER_IMPLEMENTATION=whisper .
|
2 |
+
|
3 |
+
FROM huggingface/transformers-pytorch-gpu
|
4 |
+
EXPOSE 7860
|
5 |
+
|
6 |
+
ARG WHISPER_IMPLEMENTATION=whisper
|
7 |
+
ENV WHISPER_IMPLEMENTATION=${WHISPER_IMPLEMENTATION}
|
8 |
+
|
9 |
+
ADD . /opt/whisper-webui/
|
10 |
+
|
11 |
+
# Latest version of transformers-pytorch-gpu seems to lack tk.
|
12 |
+
# Further, pip install fails, so we must upgrade pip first.
|
13 |
+
RUN apt-get -y install python3-tk
|
14 |
+
RUN python3 -m pip install --upgrade pip
|
15 |
+
|
16 |
+
RUN if [ "${WHISPER_IMPLEMENTATION}" = "whisper" ]; then \
|
17 |
+
python3 -m pip install -r /opt/whisper-webui/requirements-whisper.txt; \
|
18 |
+
else \
|
19 |
+
python3 -m pip install -r /opt/whisper-webui/requirements-fasterWhisper.txt; \
|
20 |
+
fi
|
21 |
+
|
22 |
+
# Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
|
23 |
+
# You can also bind this directory in the container to somewhere on the host.
|
24 |
+
|
25 |
+
# To be able to see logs in real time
|
26 |
+
ENV PYTHONUNBUFFERED=1
|
27 |
+
|
28 |
+
WORKDIR /opt/whisper-webui/
|
29 |
+
ENTRYPOINT ["python3"]
|
30 |
+
CMD ["app.py", "--input_audio_max_duration", "-1", "--server_name", "0.0.0.0", "--auto_parallel", "True"]
|
docs/colab.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Running Whisper on Google Colab
|
2 |
+
|
3 |
+
If you don't have a decent GPU or any experience in running command-line applications, you might want to try this Google Colab instead:
|
4 |
+
|
5 |
+
* [Google Colab - Whisper WebUI GPU](https://colab.research.google.com/drive/1qeTSvi7Bt_5RMm88ipW4fkcsMOKlDDss?usp=sharing)
|
6 |
+
* [Screenshots](https://imgur.com/a/ZfY6uBO)
|
7 |
+
|
8 |
+
The runtime (Runtime -> Change runtime type -> Hardware accelerator) should already be set top GPU. But if not, change it to GPU.
|
9 |
+
|
10 |
+
Then, sign in to Google if you haven't already. Next, click on "Connect" at the top right.
|
11 |
+
|
12 |
+
Under "Checking out WebUI from Git", click on the [play icon](https://imgur.com/a/81gOLyD) that appears in "[ ]" at the left. If you get a warning, click "Run anyway".
|
13 |
+
|
14 |
+
After this step has completed, it should be get a green check mark. Then move on to the next section under "Installing dependencies", and click in "[ ]" again. This might take approximately 30 seconds.
|
15 |
+
|
16 |
+
Once this has completed, scroll down to the "Run WebUI" section, and click on "[ ]". This will launch the WebUI in a shared link (expires in 72 hours). To open the UI, click on the link next to "Running on public URL", which will be something like https://12xxx.gradio.app/
|
17 |
+
|
18 |
+
The audio length in this version is not restricted, and it will run much faster as it is backed by a GPU. You can also run it using the "Large" model. Also note that it might take some time to start the model the first time, as it may need to download a 2.8 GB file on Google's servers.
|
19 |
+
|
20 |
+
Once you're done, you can close the WebUI session by clicking the animated close button under "Run WebUI". You can also do this if you encounter any errors and need to restart the UI. You should also go to "Manage Sessions" and terminate the session, otherwise you may end up using all your free compute credits.
|
docs/options.md
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Standard Options
|
2 |
+
To transcribe or translate an audio file, you can either copy an URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
|
3 |
+
supported by YT-DLP will work, including YouTube). Otherwise, upload an audio file (choose "All Files (*.*)"
|
4 |
+
in the file selector to select any file type, including video files) or use the microphone.
|
5 |
+
|
6 |
+
For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option, especially if you are using the `large-v1` model. Note that `large-v2` is a lot more forgiving, but you may still want to use a VAD with a slightly higher "VAD - Max Merge Size (s)" (60 seconds or more).
|
7 |
+
|
8 |
+
## Model
|
9 |
+
Select the model that Whisper will use to transcribe the audio:
|
10 |
+
|
11 |
+
| Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
|
12 |
+
|-----------|------------|--------------------|--------------------|---------------|----------------|
|
13 |
+
| tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
|
14 |
+
| base | 74 M | base.en | base | ~1 GB | ~16x |
|
15 |
+
| small | 244 M | small.en | small | ~2 GB | ~6x |
|
16 |
+
| medium | 769 M | medium.en | medium | ~5 GB | ~2x |
|
17 |
+
| large | 1550 M | N/A | large | ~10 GB | 1x |
|
18 |
+
| large-v2 | 1550 M | N/A | large | ~10 GB | 1x |
|
19 |
+
|
20 |
+
## Language
|
21 |
+
|
22 |
+
Select the language, or leave it empty for Whisper to automatically detect it.
|
23 |
+
|
24 |
+
Note that if the selected language and the language in the audio differs, Whisper may start to translate the audio to the selected
|
25 |
+
language. For instance, if the audio is in English but you select Japaneese, the model may translate the audio to Japanese.
|
26 |
+
|
27 |
+
## Inputs
|
28 |
+
The options "URL (YouTube, etc.)", "Upload Files" or "Micriphone Input" allows you to send an audio input to the model.
|
29 |
+
|
30 |
+
### Multiple Files
|
31 |
+
Note that the UI will only process either the given URL or the upload files (including microphone) - not both.
|
32 |
+
|
33 |
+
But you can upload multiple files either through the "Upload files" option, or as a playlist on YouTube. Each audio file will then be processed in turn, and the resulting SRT/VTT/Transcript will be made available in the "Download" section. When more than one file is processed, the UI will also generate a "All_Output" zip file containing all the text output files.
|
34 |
+
|
35 |
+
## Task
|
36 |
+
Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
|
37 |
+
|
38 |
+
## Vad
|
39 |
+
Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper getting into an infinite
|
40 |
+
loop detecting the same sentence over and over again. The downside is that this may be at a cost to text accuracy, especially
|
41 |
+
with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
|
42 |
+
|
43 |
+
Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
|
44 |
+
So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long.
|
45 |
+
|
46 |
+
* none
|
47 |
+
* Run whisper on the entire audio input
|
48 |
+
* silero-vad
|
49 |
+
* Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Whisper is also run
|
50 |
+
on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
|
51 |
+
on the non-speech section.
|
52 |
+
* silero-vad-expand-into-gaps
|
53 |
+
* Use Silero VAD to detect sections that contain speech, and run Whisper on independently on each section. Each spech section will be expanded
|
54 |
+
such that they cover any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
|
55 |
+
00:00 - 00:10 (A) and 00:30 - 00:40 (B), the first section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 00:60.
|
56 |
+
* silero-vad-skip-gaps
|
57 |
+
* As above, but sections that doesn't contain speech according to Silero will be skipped. This will be slightly faster, but
|
58 |
+
may cause dialogue to be skipped.
|
59 |
+
* periodic-vad
|
60 |
+
* Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
|
61 |
+
a sentence or word in two.
|
62 |
+
|
63 |
+
## VAD - Merge Window
|
64 |
+
If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
|
65 |
+
|
66 |
+
## VAD - Max Merge Size (s)
|
67 |
+
Disables merging of adjacent speech sections if they are this number of seconds long.
|
68 |
+
|
69 |
+
## VAD - Padding (s)
|
70 |
+
The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
|
71 |
+
larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of
|
72 |
+
a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
|
73 |
+
to each transcribed line. The default value is 1 second.
|
74 |
+
|
75 |
+
## VAD - Prompt Window (s)
|
76 |
+
The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
|
77 |
+
number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
|
78 |
+
10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
|
79 |
+
|
80 |
+
Note that detected lines in gaps between speech sections will not be included in the prompt
|
81 |
+
(if silero-vad or silero-vad-expand-into-gaps) is used.
|
82 |
+
|
83 |
+
# Command Line Options
|
84 |
+
|
85 |
+
Both `app.py` and `cli.py` also accept command line options, such as the ability to enable parallel execution on multiple
|
86 |
+
CPU/GPU cores, the default model name/VAD and so on. Consult the README in the root folder for more information.
|
87 |
+
|
88 |
+
# Additional Options
|
89 |
+
|
90 |
+
In addition to the above, there's also a "Full" options interface that allows you to set all the options available in the Whisper
|
91 |
+
model. The options are as follows:
|
92 |
+
|
93 |
+
## Initial Prompt
|
94 |
+
Optional text to provide as a prompt for the first 30 seconds window. Whisper will attempt to use this as a starting point for the transcription, but you can
|
95 |
+
also get creative and specify a style or format for the output of the transcription.
|
96 |
+
|
97 |
+
For instance, if you use the prompt "hello how is it going always use lowercase no punctuation goodbye one two three start stop i you me they", Whisper will
|
98 |
+
be biased to output lower capital letters and no punctuation, and may also be biased to output the words in the prompt more often.
|
99 |
+
|
100 |
+
## Temperature
|
101 |
+
The temperature to use when sampling. Default is 0 (zero). A higher temperature will result in more random output, while a lower temperature will be more deterministic.
|
102 |
+
|
103 |
+
## Best Of - Non-zero temperature
|
104 |
+
The number of candidates to sample from when sampling with non-zero temperature. Default is 5.
|
105 |
+
|
106 |
+
## Beam Size - Zero temperature
|
107 |
+
The number of beams to use in beam search when sampling with zero temperature. Default is 5.
|
108 |
+
|
109 |
+
## Patience - Zero temperature
|
110 |
+
The patience value to use in beam search when sampling with zero temperature. As in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search.
|
111 |
+
|
112 |
+
## Length Penalty - Any temperature
|
113 |
+
The token length penalty coefficient (alpha) to use when sampling with any temperature. As in https://arxiv.org/abs/1609.08144, uses simple length normalization by default.
|
114 |
+
|
115 |
+
## Suppress Tokens - Comma-separated list of token IDs
|
116 |
+
A comma-separated list of token IDs to suppress during sampling. The default value of "-1" will suppress most special characters except common punctuations.
|
117 |
+
|
118 |
+
## Condition on previous text
|
119 |
+
If True, provide the previous output of the model as a prompt for the next window. Disabling this may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop.
|
120 |
+
|
121 |
+
## FP16
|
122 |
+
Whether to perform inference in fp16. True by default.
|
123 |
+
|
124 |
+
## Temperature increment on fallback
|
125 |
+
The temperature to increase when falling back when the decoding fails to meet either of the thresholds below. Default is 0.2.
|
126 |
+
|
127 |
+
## Compression ratio threshold
|
128 |
+
If the gzip compression ratio is higher than this value, treat the decoding as failed. Default is 2.4.
|
129 |
+
|
130 |
+
## Logprob threshold
|
131 |
+
If the average log probability is lower than this value, treat the decoding as failed. Default is -1.0.
|
132 |
+
|
133 |
+
## No speech threshold
|
134 |
+
If the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence. Default is 0.6.
|
docs/windows/install_win10_win11.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9b9f4ed547d6534411c17da1ea56707d2ec6e812611b1cbd3098756d5cbb8084
|
3 |
+
size 3378789
|
requirements-fasterWhisper.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ctranslate2
|
2 |
+
faster-whisper
|
3 |
+
ffmpeg-python==0.2.0
|
4 |
+
gradio==3.23.0
|
5 |
+
yt-dlp
|
6 |
+
json5
|
7 |
+
torch
|
8 |
+
torchaudio
|
9 |
+
more_itertools
|
requirements-whisper.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/transformers
|
2 |
+
git+https://github.com/openai/whisper.git
|
3 |
+
transformers
|
4 |
+
ffmpeg-python==0.2.0
|
5 |
+
gradio==3.23.0
|
6 |
+
yt-dlp
|
7 |
+
torchaudio
|
8 |
+
altair
|
9 |
+
json5
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/transformers
|
2 |
+
git+https://github.com/openai/whisper.git
|
3 |
+
transformers
|
4 |
+
ffmpeg-python==0.2.0
|
5 |
+
gradio==3.23.0
|
6 |
+
yt-dlp
|
7 |
+
torchaudio
|
8 |
+
altair
|
9 |
+
json5
|
src/__init__.py
ADDED
File without changes
|
src/config.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import urllib
|
3 |
+
|
4 |
+
import os
|
5 |
+
from typing import List
|
6 |
+
from urllib.parse import urlparse
|
7 |
+
import json5
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
class ModelConfig:
|
13 |
+
def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
|
14 |
+
"""
|
15 |
+
Initialize a model configuration.
|
16 |
+
|
17 |
+
name: Name of the model
|
18 |
+
url: URL to download the model from
|
19 |
+
path: Path to the model file. If not set, the model will be downloaded from the URL.
|
20 |
+
type: Type of model. Can be whisper or huggingface.
|
21 |
+
"""
|
22 |
+
self.name = name
|
23 |
+
self.url = url
|
24 |
+
self.path = path
|
25 |
+
self.type = type
|
26 |
+
|
27 |
+
VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
|
28 |
+
|
29 |
+
class VadInitialPromptMode(Enum):
|
30 |
+
PREPEND_ALL_SEGMENTS = 1
|
31 |
+
PREPREND_FIRST_SEGMENT = 2
|
32 |
+
JSON_PROMPT_MODE = 3
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def from_string(s: str):
|
36 |
+
normalized = s.lower() if s is not None else None
|
37 |
+
|
38 |
+
if normalized == "prepend_all_segments":
|
39 |
+
return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
|
40 |
+
elif normalized == "prepend_first_segment":
|
41 |
+
return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
|
42 |
+
elif normalized == "json_prompt_mode":
|
43 |
+
return VadInitialPromptMode.JSON_PROMPT_MODE
|
44 |
+
elif normalized is not None and normalized != "":
|
45 |
+
raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
|
46 |
+
else:
|
47 |
+
return None
|
48 |
+
|
49 |
+
class ApplicationConfig:
|
50 |
+
def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
|
51 |
+
share: bool = False, server_name: str = None, server_port: int = 7860,
|
52 |
+
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
53 |
+
whisper_implementation: str = "whisper",
|
54 |
+
default_model_name: str = "medium", default_vad: str = "silero-vad",
|
55 |
+
vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
|
56 |
+
auto_parallel: bool = False, output_dir: str = None,
|
57 |
+
model_dir: str = None, device: str = None,
|
58 |
+
verbose: bool = True, task: str = "transcribe", language: str = None,
|
59 |
+
vad_initial_prompt_mode: str = "prepend_first_segment ",
|
60 |
+
vad_merge_window: float = 5, vad_max_merge_size: float = 30,
|
61 |
+
vad_padding: float = 1, vad_prompt_window: float = 3,
|
62 |
+
temperature: float = 0, best_of: int = 5, beam_size: int = 5,
|
63 |
+
patience: float = None, length_penalty: float = None,
|
64 |
+
suppress_tokens: str = "-1", initial_prompt: str = None,
|
65 |
+
condition_on_previous_text: bool = True, fp16: bool = True,
|
66 |
+
compute_type: str = "float16",
|
67 |
+
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
68 |
+
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
|
69 |
+
# Word timestamp settings
|
70 |
+
word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
|
71 |
+
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
72 |
+
highlight_words: bool = False):
|
73 |
+
|
74 |
+
self.models = models
|
75 |
+
|
76 |
+
# WebUI settings
|
77 |
+
self.input_audio_max_duration = input_audio_max_duration
|
78 |
+
self.share = share
|
79 |
+
self.server_name = server_name
|
80 |
+
self.server_port = server_port
|
81 |
+
self.queue_concurrency_count = queue_concurrency_count
|
82 |
+
self.delete_uploaded_files = delete_uploaded_files
|
83 |
+
|
84 |
+
self.whisper_implementation = whisper_implementation
|
85 |
+
self.default_model_name = default_model_name
|
86 |
+
self.default_vad = default_vad
|
87 |
+
self.vad_parallel_devices = vad_parallel_devices
|
88 |
+
self.vad_cpu_cores = vad_cpu_cores
|
89 |
+
self.vad_process_timeout = vad_process_timeout
|
90 |
+
self.auto_parallel = auto_parallel
|
91 |
+
self.output_dir = output_dir
|
92 |
+
|
93 |
+
self.model_dir = model_dir
|
94 |
+
self.device = device
|
95 |
+
self.verbose = verbose
|
96 |
+
self.task = task
|
97 |
+
self.language = language
|
98 |
+
self.vad_initial_prompt_mode = vad_initial_prompt_mode
|
99 |
+
self.vad_merge_window = vad_merge_window
|
100 |
+
self.vad_max_merge_size = vad_max_merge_size
|
101 |
+
self.vad_padding = vad_padding
|
102 |
+
self.vad_prompt_window = vad_prompt_window
|
103 |
+
self.temperature = temperature
|
104 |
+
self.best_of = best_of
|
105 |
+
self.beam_size = beam_size
|
106 |
+
self.patience = patience
|
107 |
+
self.length_penalty = length_penalty
|
108 |
+
self.suppress_tokens = suppress_tokens
|
109 |
+
self.initial_prompt = initial_prompt
|
110 |
+
self.condition_on_previous_text = condition_on_previous_text
|
111 |
+
self.fp16 = fp16
|
112 |
+
self.compute_type = compute_type
|
113 |
+
self.temperature_increment_on_fallback = temperature_increment_on_fallback
|
114 |
+
self.compression_ratio_threshold = compression_ratio_threshold
|
115 |
+
self.logprob_threshold = logprob_threshold
|
116 |
+
self.no_speech_threshold = no_speech_threshold
|
117 |
+
|
118 |
+
# Word timestamp settings
|
119 |
+
self.word_timestamps = word_timestamps
|
120 |
+
self.prepend_punctuations = prepend_punctuations
|
121 |
+
self.append_punctuations = append_punctuations
|
122 |
+
self.highlight_words = highlight_words
|
123 |
+
|
124 |
+
def get_model_names(self):
|
125 |
+
return [ x.name for x in self.models ]
|
126 |
+
|
127 |
+
def update(self, **new_values):
|
128 |
+
result = ApplicationConfig(**self.__dict__)
|
129 |
+
|
130 |
+
for key, value in new_values.items():
|
131 |
+
setattr(result, key, value)
|
132 |
+
return result
|
133 |
+
|
134 |
+
@staticmethod
|
135 |
+
def create_default(**kwargs):
|
136 |
+
app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
|
137 |
+
|
138 |
+
# Update with kwargs
|
139 |
+
if len(kwargs) > 0:
|
140 |
+
app_config = app_config.update(**kwargs)
|
141 |
+
return app_config
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def parse_file(config_path: str):
|
145 |
+
import json5
|
146 |
+
|
147 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
148 |
+
# Load using json5
|
149 |
+
data = json5.load(f)
|
150 |
+
data_models = data.pop("models", [])
|
151 |
+
|
152 |
+
models = [ ModelConfig(**x) for x in data_models ]
|
153 |
+
|
154 |
+
return ApplicationConfig(models, **data)
|
src/conversion/hf_converter.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/bayartsogt-ya/whisper-multiple-hf-datasets
|
2 |
+
|
3 |
+
from copy import deepcopy
|
4 |
+
import torch
|
5 |
+
|
6 |
+
WHISPER_MAPPING = {
|
7 |
+
"layers": "blocks",
|
8 |
+
"fc1": "mlp.0",
|
9 |
+
"fc2": "mlp.2",
|
10 |
+
"final_layer_norm": "mlp_ln",
|
11 |
+
"layers": "blocks",
|
12 |
+
".self_attn.q_proj": ".attn.query",
|
13 |
+
".self_attn.k_proj": ".attn.key",
|
14 |
+
".self_attn.v_proj": ".attn.value",
|
15 |
+
".self_attn_layer_norm": ".attn_ln",
|
16 |
+
".self_attn.out_proj": ".attn.out",
|
17 |
+
".encoder_attn.q_proj": ".cross_attn.query",
|
18 |
+
".encoder_attn.k_proj": ".cross_attn.key",
|
19 |
+
".encoder_attn.v_proj": ".cross_attn.value",
|
20 |
+
".encoder_attn_layer_norm": ".cross_attn_ln",
|
21 |
+
".encoder_attn.out_proj": ".cross_attn.out",
|
22 |
+
"decoder.layer_norm.": "decoder.ln.",
|
23 |
+
"encoder.layer_norm.": "encoder.ln_post.",
|
24 |
+
"embed_tokens": "token_embedding",
|
25 |
+
"encoder.embed_positions.weight": "encoder.positional_embedding",
|
26 |
+
"decoder.embed_positions.weight": "decoder.positional_embedding",
|
27 |
+
"layer_norm": "ln_post",
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def rename_keys(s_dict):
|
32 |
+
keys = list(s_dict.keys())
|
33 |
+
for key in keys:
|
34 |
+
new_key = key
|
35 |
+
for k, v in WHISPER_MAPPING.items():
|
36 |
+
if k in key:
|
37 |
+
new_key = new_key.replace(k, v)
|
38 |
+
|
39 |
+
print(f"{key} -> {new_key}")
|
40 |
+
|
41 |
+
s_dict[new_key] = s_dict.pop(key)
|
42 |
+
return s_dict
|
43 |
+
|
44 |
+
|
45 |
+
def convert_hf_whisper(hf_model_name_or_path: str, whisper_state_path: str):
|
46 |
+
from transformers import WhisperForConditionalGeneration
|
47 |
+
transformer_model = WhisperForConditionalGeneration.from_pretrained(hf_model_name_or_path)
|
48 |
+
config = transformer_model.config
|
49 |
+
|
50 |
+
# first build dims
|
51 |
+
dims = {
|
52 |
+
'n_mels': config.num_mel_bins,
|
53 |
+
'n_vocab': config.vocab_size,
|
54 |
+
'n_audio_ctx': config.max_source_positions,
|
55 |
+
'n_audio_state': config.d_model,
|
56 |
+
'n_audio_head': config.encoder_attention_heads,
|
57 |
+
'n_audio_layer': config.encoder_layers,
|
58 |
+
'n_text_ctx': config.max_target_positions,
|
59 |
+
'n_text_state': config.d_model,
|
60 |
+
'n_text_head': config.decoder_attention_heads,
|
61 |
+
'n_text_layer': config.decoder_layers
|
62 |
+
}
|
63 |
+
|
64 |
+
state_dict = deepcopy(transformer_model.model.state_dict())
|
65 |
+
state_dict = rename_keys(state_dict)
|
66 |
+
|
67 |
+
torch.save({"dims": dims, "model_state_dict": state_dict}, whisper_state_path)
|
src/download.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tempfile import mkdtemp
|
2 |
+
from typing import List
|
3 |
+
from yt_dlp import YoutubeDL
|
4 |
+
|
5 |
+
import yt_dlp
|
6 |
+
from yt_dlp.postprocessor import PostProcessor
|
7 |
+
|
8 |
+
class FilenameCollectorPP(PostProcessor):
|
9 |
+
def __init__(self):
|
10 |
+
super(FilenameCollectorPP, self).__init__(None)
|
11 |
+
self.filenames = []
|
12 |
+
|
13 |
+
def run(self, information):
|
14 |
+
self.filenames.append(information["filepath"])
|
15 |
+
return [], information
|
16 |
+
|
17 |
+
def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
|
18 |
+
try:
|
19 |
+
return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
|
20 |
+
except yt_dlp.utils.DownloadError as e:
|
21 |
+
# In case of an OS error, try again with a different output template
|
22 |
+
if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
|
23 |
+
return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
|
24 |
+
pass
|
25 |
+
|
26 |
+
def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
|
27 |
+
# Create a temporary directory to store the downloaded files
|
28 |
+
if destinationDirectory is None:
|
29 |
+
destinationDirectory = mkdtemp()
|
30 |
+
|
31 |
+
ydl_opts = {
|
32 |
+
"format": "bestaudio/best",
|
33 |
+
'paths': {
|
34 |
+
'home': destinationDirectory
|
35 |
+
}
|
36 |
+
}
|
37 |
+
if (playlistItems):
|
38 |
+
ydl_opts['playlist_items'] = playlistItems
|
39 |
+
|
40 |
+
# Add output template if specified
|
41 |
+
if outputTemplate:
|
42 |
+
ydl_opts['outtmpl'] = outputTemplate
|
43 |
+
|
44 |
+
filename_collector = FilenameCollectorPP()
|
45 |
+
|
46 |
+
with YoutubeDL(ydl_opts) as ydl:
|
47 |
+
if maxDuration and maxDuration > 0:
|
48 |
+
info = ydl.extract_info(url, download=False)
|
49 |
+
entries = "entries" in info and info["entries"] or [info]
|
50 |
+
|
51 |
+
total_duration = 0
|
52 |
+
|
53 |
+
# Compute total duration
|
54 |
+
for entry in entries:
|
55 |
+
total_duration += float(entry["duration"])
|
56 |
+
|
57 |
+
if total_duration >= maxDuration:
|
58 |
+
raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")
|
59 |
+
|
60 |
+
ydl.add_post_processor(filename_collector)
|
61 |
+
ydl.download([url])
|
62 |
+
|
63 |
+
if len(filename_collector.filenames) <= 0:
|
64 |
+
raise Exception("Cannot download " + url)
|
65 |
+
|
66 |
+
result = []
|
67 |
+
|
68 |
+
for filename in filename_collector.filenames:
|
69 |
+
result.append(filename)
|
70 |
+
print("Downloaded " + filename)
|
71 |
+
|
72 |
+
return result
|
73 |
+
|
74 |
+
class ExceededMaximumDuration(Exception):
|
75 |
+
def __init__(self, videoDuration, maxDuration, message):
|
76 |
+
self.videoDuration = videoDuration
|
77 |
+
self.maxDuration = maxDuration
|
78 |
+
super().__init__(message)
|
src/hooks/progressListener.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
class ProgressListener:
|
4 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
5 |
+
self.total = total
|
6 |
+
|
7 |
+
def on_finished(self):
|
8 |
+
pass
|
src/hooks/subTaskProgressListener.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.hooks.progressListener import ProgressListener
|
2 |
+
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
class SubTaskProgressListener(ProgressListener):
|
6 |
+
"""
|
7 |
+
A sub task listener that reports the progress of a sub task to a base task listener
|
8 |
+
Parameters
|
9 |
+
----------
|
10 |
+
base_task_listener : ProgressListener
|
11 |
+
The base progress listener to accumulate overall progress in.
|
12 |
+
base_task_total : float
|
13 |
+
The maximum total progress that will be reported to the base progress listener.
|
14 |
+
sub_task_start : float
|
15 |
+
The starting progress of a sub task, in respect to the base progress listener.
|
16 |
+
sub_task_total : float
|
17 |
+
The total amount of progress a sub task will report to the base progress listener.
|
18 |
+
"""
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
base_task_listener: ProgressListener,
|
22 |
+
base_task_total: float,
|
23 |
+
sub_task_start: float,
|
24 |
+
sub_task_total: float,
|
25 |
+
):
|
26 |
+
self.base_task_listener = base_task_listener
|
27 |
+
self.base_task_total = base_task_total
|
28 |
+
self.sub_task_start = sub_task_start
|
29 |
+
self.sub_task_total = sub_task_total
|
30 |
+
|
31 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
32 |
+
sub_task_progress_frac = current / total
|
33 |
+
sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
|
34 |
+
self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)
|
35 |
+
|
36 |
+
def on_finished(self):
|
37 |
+
self.base_task_listener.on_progress(self.sub_task_start + self.sub_task_total, self.base_task_total)
|
src/hooks/whisperProgressHook.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import threading
|
3 |
+
from typing import List, Union
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
from src.hooks.progressListener import ProgressListener
|
7 |
+
|
8 |
+
class ProgressListenerHandle:
|
9 |
+
def __init__(self, listener: ProgressListener):
|
10 |
+
self.listener = listener
|
11 |
+
|
12 |
+
def __enter__(self):
|
13 |
+
register_thread_local_progress_listener(self.listener)
|
14 |
+
|
15 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
16 |
+
unregister_thread_local_progress_listener(self.listener)
|
17 |
+
|
18 |
+
if exc_type is None:
|
19 |
+
self.listener.on_finished()
|
20 |
+
|
21 |
+
class _CustomProgressBar(tqdm.tqdm):
|
22 |
+
def __init__(self, *args, **kwargs):
|
23 |
+
super().__init__(*args, **kwargs)
|
24 |
+
self._current = self.n # Set the initial value
|
25 |
+
|
26 |
+
def update(self, n):
|
27 |
+
super().update(n)
|
28 |
+
# Because the progress bar might be disabled, we need to manually update the progress
|
29 |
+
self._current += n
|
30 |
+
|
31 |
+
# Inform listeners
|
32 |
+
listeners = _get_thread_local_listeners()
|
33 |
+
|
34 |
+
for listener in listeners:
|
35 |
+
listener.on_progress(self._current, self.total)
|
36 |
+
|
37 |
+
_thread_local = threading.local()
|
38 |
+
|
39 |
+
def _get_thread_local_listeners():
|
40 |
+
if not hasattr(_thread_local, 'listeners'):
|
41 |
+
_thread_local.listeners = []
|
42 |
+
return _thread_local.listeners
|
43 |
+
|
44 |
+
_hooked = False
|
45 |
+
|
46 |
+
def init_progress_hook():
|
47 |
+
global _hooked
|
48 |
+
|
49 |
+
if _hooked:
|
50 |
+
return
|
51 |
+
|
52 |
+
# Inject into tqdm.tqdm of Whisper, so we can see progress
|
53 |
+
import whisper.transcribe
|
54 |
+
transcribe_module = sys.modules['whisper.transcribe']
|
55 |
+
transcribe_module.tqdm.tqdm = _CustomProgressBar
|
56 |
+
_hooked = True
|
57 |
+
|
58 |
+
def register_thread_local_progress_listener(progress_listener: ProgressListener):
|
59 |
+
# This is a workaround for the fact that the progress bar is not exposed in the API
|
60 |
+
init_progress_hook()
|
61 |
+
|
62 |
+
listeners = _get_thread_local_listeners()
|
63 |
+
listeners.append(progress_listener)
|
64 |
+
|
65 |
+
def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
|
66 |
+
listeners = _get_thread_local_listeners()
|
67 |
+
|
68 |
+
if progress_listener in listeners:
|
69 |
+
listeners.remove(progress_listener)
|
70 |
+
|
71 |
+
def create_progress_listener_handle(progress_listener: ProgressListener):
|
72 |
+
return ProgressListenerHandle(progress_listener)
|
73 |
+
|
74 |
+
# Example usage
|
75 |
+
if __name__ == '__main__':
|
76 |
+
class PrintingProgressListener:
|
77 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
78 |
+
print(f"Progress: {current}/{total}")
|
79 |
+
|
80 |
+
def on_finished(self):
|
81 |
+
print("Finished")
|
82 |
+
|
83 |
+
import whisper
|
84 |
+
model = whisper.load_model("medium")
|
85 |
+
|
86 |
+
with create_progress_listener_handle(PrintingProgressListener()) as listener:
|
87 |
+
# Set verbose to None to disable the progress bar, as we are using our own
|
88 |
+
result = model.transcribe("J:\\Dev\\OpenAI\\whisper\\tests\\Noriko\\out.mka", language="Japanese", fp16=False, verbose=None)
|
89 |
+
print(result)
|
90 |
+
|
91 |
+
print("Done")
|
src/languages.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Language():
|
2 |
+
def __init__(self, code, name):
|
3 |
+
self.code = code
|
4 |
+
self.name = name
|
5 |
+
|
6 |
+
def __str__(self):
|
7 |
+
return "Language(code={}, name={})".format(self.code, self.name)
|
8 |
+
|
9 |
+
LANGUAGES = [
|
10 |
+
Language('en', 'English'),
|
11 |
+
Language('zh', 'Chinese'),
|
12 |
+
Language('de', 'German'),
|
13 |
+
Language('es', 'Spanish'),
|
14 |
+
Language('ru', 'Russian'),
|
15 |
+
Language('ko', 'Korean'),
|
16 |
+
Language('fr', 'French'),
|
17 |
+
Language('ja', 'Japanese'),
|
18 |
+
Language('pt', 'Portuguese'),
|
19 |
+
Language('tr', 'Turkish'),
|
20 |
+
Language('pl', 'Polish'),
|
21 |
+
Language('ca', 'Catalan'),
|
22 |
+
Language('nl', 'Dutch'),
|
23 |
+
Language('ar', 'Arabic'),
|
24 |
+
Language('sv', 'Swedish'),
|
25 |
+
Language('it', 'Italian'),
|
26 |
+
Language('id', 'Indonesian'),
|
27 |
+
Language('hi', 'Hindi'),
|
28 |
+
Language('fi', 'Finnish'),
|
29 |
+
Language('vi', 'Vietnamese'),
|
30 |
+
Language('he', 'Hebrew'),
|
31 |
+
Language('uk', 'Ukrainian'),
|
32 |
+
Language('el', 'Greek'),
|
33 |
+
Language('ms', 'Malay'),
|
34 |
+
Language('cs', 'Czech'),
|
35 |
+
Language('ro', 'Romanian'),
|
36 |
+
Language('da', 'Danish'),
|
37 |
+
Language('hu', 'Hungarian'),
|
38 |
+
Language('ta', 'Tamil'),
|
39 |
+
Language('no', 'Norwegian'),
|
40 |
+
Language('th', 'Thai'),
|
41 |
+
Language('ur', 'Urdu'),
|
42 |
+
Language('hr', 'Croatian'),
|
43 |
+
Language('bg', 'Bulgarian'),
|
44 |
+
Language('lt', 'Lithuanian'),
|
45 |
+
Language('la', 'Latin'),
|
46 |
+
Language('mi', 'Maori'),
|
47 |
+
Language('ml', 'Malayalam'),
|
48 |
+
Language('cy', 'Welsh'),
|
49 |
+
Language('sk', 'Slovak'),
|
50 |
+
Language('te', 'Telugu'),
|
51 |
+
Language('fa', 'Persian'),
|
52 |
+
Language('lv', 'Latvian'),
|
53 |
+
Language('bn', 'Bengali'),
|
54 |
+
Language('sr', 'Serbian'),
|
55 |
+
Language('az', 'Azerbaijani'),
|
56 |
+
Language('sl', 'Slovenian'),
|
57 |
+
Language('kn', 'Kannada'),
|
58 |
+
Language('et', 'Estonian'),
|
59 |
+
Language('mk', 'Macedonian'),
|
60 |
+
Language('br', 'Breton'),
|
61 |
+
Language('eu', 'Basque'),
|
62 |
+
Language('is', 'Icelandic'),
|
63 |
+
Language('hy', 'Armenian'),
|
64 |
+
Language('ne', 'Nepali'),
|
65 |
+
Language('mn', 'Mongolian'),
|
66 |
+
Language('bs', 'Bosnian'),
|
67 |
+
Language('kk', 'Kazakh'),
|
68 |
+
Language('sq', 'Albanian'),
|
69 |
+
Language('sw', 'Swahili'),
|
70 |
+
Language('gl', 'Galician'),
|
71 |
+
Language('mr', 'Marathi'),
|
72 |
+
Language('pa', 'Punjabi'),
|
73 |
+
Language('si', 'Sinhala'),
|
74 |
+
Language('km', 'Khmer'),
|
75 |
+
Language('sn', 'Shona'),
|
76 |
+
Language('yo', 'Yoruba'),
|
77 |
+
Language('so', 'Somali'),
|
78 |
+
Language('af', 'Afrikaans'),
|
79 |
+
Language('oc', 'Occitan'),
|
80 |
+
Language('ka', 'Georgian'),
|
81 |
+
Language('be', 'Belarusian'),
|
82 |
+
Language('tg', 'Tajik'),
|
83 |
+
Language('sd', 'Sindhi'),
|
84 |
+
Language('gu', 'Gujarati'),
|
85 |
+
Language('am', 'Amharic'),
|
86 |
+
Language('yi', 'Yiddish'),
|
87 |
+
Language('lo', 'Lao'),
|
88 |
+
Language('uz', 'Uzbek'),
|
89 |
+
Language('fo', 'Faroese'),
|
90 |
+
Language('ht', 'Haitian creole'),
|
91 |
+
Language('ps', 'Pashto'),
|
92 |
+
Language('tk', 'Turkmen'),
|
93 |
+
Language('nn', 'Nynorsk'),
|
94 |
+
Language('mt', 'Maltese'),
|
95 |
+
Language('sa', 'Sanskrit'),
|
96 |
+
Language('lb', 'Luxembourgish'),
|
97 |
+
Language('my', 'Myanmar'),
|
98 |
+
Language('bo', 'Tibetan'),
|
99 |
+
Language('tl', 'Tagalog'),
|
100 |
+
Language('mg', 'Malagasy'),
|
101 |
+
Language('as', 'Assamese'),
|
102 |
+
Language('tt', 'Tatar'),
|
103 |
+
Language('haw', 'Hawaiian'),
|
104 |
+
Language('ln', 'Lingala'),
|
105 |
+
Language('ha', 'Hausa'),
|
106 |
+
Language('ba', 'Bashkir'),
|
107 |
+
Language('jw', 'Javanese'),
|
108 |
+
Language('su', 'Sundanese')
|
109 |
+
]
|
110 |
+
|
111 |
+
_TO_LANGUAGE_CODE = {
|
112 |
+
**{language.code: language for language in LANGUAGES},
|
113 |
+
"burmese": "my",
|
114 |
+
"valencian": "ca",
|
115 |
+
"flemish": "nl",
|
116 |
+
"haitian": "ht",
|
117 |
+
"letzeburgesch": "lb",
|
118 |
+
"pushto": "ps",
|
119 |
+
"panjabi": "pa",
|
120 |
+
"moldavian": "ro",
|
121 |
+
"moldovan": "ro",
|
122 |
+
"sinhalese": "si",
|
123 |
+
"castilian": "es",
|
124 |
+
}
|
125 |
+
|
126 |
+
_FROM_LANGUAGE_NAME = {
|
127 |
+
**{language.name.lower(): language for language in LANGUAGES}
|
128 |
+
}
|
129 |
+
|
130 |
+
def get_language_from_code(language_code, default=None) -> Language:
|
131 |
+
"""Return the language name from the language code."""
|
132 |
+
return _TO_LANGUAGE_CODE.get(language_code, default)
|
133 |
+
|
134 |
+
def get_language_from_name(language, default=None) -> Language:
|
135 |
+
"""Return the language code from the language name."""
|
136 |
+
return _FROM_LANGUAGE_NAME.get(language.lower() if language else None, default)
|
137 |
+
|
138 |
+
def get_language_names():
|
139 |
+
"""Return a list of language names."""
|
140 |
+
return [language.name for language in LANGUAGES]
|
141 |
+
|
142 |
+
if __name__ == "__main__":
|
143 |
+
# Test lookup
|
144 |
+
print(get_language_from_code('en'))
|
145 |
+
print(get_language_from_name('English'))
|
146 |
+
|
147 |
+
print(get_language_names())
|
src/modelCache.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class ModelCache:
|
2 |
+
def __init__(self):
|
3 |
+
self._cache = dict()
|
4 |
+
|
5 |
+
def get(self, model_key: str, model_factory):
|
6 |
+
result = self._cache.get(model_key)
|
7 |
+
|
8 |
+
if result is None:
|
9 |
+
result = model_factory()
|
10 |
+
self._cache[model_key] = result
|
11 |
+
return result
|
12 |
+
|
13 |
+
def clear(self):
|
14 |
+
self._cache.clear()
|
15 |
+
|
16 |
+
# A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
|
17 |
+
GLOBAL_MODEL_CACHE = ModelCache()
|
src/prompts/abstractPromptStrategy.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
|
3 |
+
|
4 |
+
class AbstractPromptStrategy:
|
5 |
+
"""
|
6 |
+
Represents a strategy for generating prompts for a given audio segment.
|
7 |
+
|
8 |
+
Note that the strategy must be picklable, as it will be serialized and sent to the workers.
|
9 |
+
"""
|
10 |
+
|
11 |
+
@abc.abstractmethod
|
12 |
+
def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
|
13 |
+
"""
|
14 |
+
Retrieves the prompt for a given segment.
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
segment_index: int
|
19 |
+
The index of the segment.
|
20 |
+
whisper_prompt: str
|
21 |
+
The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
|
22 |
+
detected_language: str
|
23 |
+
The language detected for the segment.
|
24 |
+
"""
|
25 |
+
pass
|
26 |
+
|
27 |
+
@abc.abstractmethod
|
28 |
+
def on_segment_finished(self, segment_index: int, whisper_prompt: str, detected_language: str, result: dict):
|
29 |
+
"""
|
30 |
+
Called when a segment has finished processing.
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
segment_index: int
|
35 |
+
The index of the segment.
|
36 |
+
whisper_prompt: str
|
37 |
+
The prompt for the segment generated by Whisper. This is typically concatenated with the initial prompt.
|
38 |
+
detected_language: str
|
39 |
+
The language detected for the segment.
|
40 |
+
result: dict
|
41 |
+
The result of the segment. It has the following format:
|
42 |
+
{
|
43 |
+
"text": str,
|
44 |
+
"segments": [
|
45 |
+
{
|
46 |
+
"text": str,
|
47 |
+
"start": float,
|
48 |
+
"end": float,
|
49 |
+
"words": [words],
|
50 |
+
}
|
51 |
+
],
|
52 |
+
"language": str,
|
53 |
+
}
|
54 |
+
"""
|
55 |
+
pass
|
56 |
+
|
57 |
+
def _concat_prompt(self, prompt1, prompt2):
|
58 |
+
"""
|
59 |
+
Concatenates two prompts.
|
60 |
+
|
61 |
+
Parameters
|
62 |
+
----------
|
63 |
+
prompt1: str
|
64 |
+
The first prompt.
|
65 |
+
prompt2: str
|
66 |
+
The second prompt.
|
67 |
+
"""
|
68 |
+
if (prompt1 is None):
|
69 |
+
return prompt2
|
70 |
+
elif (prompt2 is None):
|
71 |
+
return prompt1
|
72 |
+
else:
|
73 |
+
return prompt1 + " " + prompt2
|
src/prompts/jsonPromptStrategy.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Dict
|
3 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
4 |
+
|
5 |
+
|
6 |
+
class JsonPromptSegment():
|
7 |
+
def __init__(self, segment_index: int, prompt: str, format_prompt: bool = False):
|
8 |
+
self.prompt = prompt
|
9 |
+
self.segment_index = segment_index
|
10 |
+
self.format_prompt = format_prompt
|
11 |
+
|
12 |
+
class JsonPromptStrategy(AbstractPromptStrategy):
|
13 |
+
def __init__(self, initial_json_prompt: str):
|
14 |
+
"""
|
15 |
+
Parameters
|
16 |
+
----------
|
17 |
+
initial_json_prompt: str
|
18 |
+
The initial prompts for each segment in JSON form.
|
19 |
+
|
20 |
+
Format:
|
21 |
+
[
|
22 |
+
{"segment_index": 0, "prompt": "Hello, how are you?"},
|
23 |
+
{"segment_index": 1, "prompt": "I'm doing well, how are you?"},
|
24 |
+
{"segment_index": 2, "prompt": "{0} Fine, thank you.", "format_prompt": true}
|
25 |
+
]
|
26 |
+
|
27 |
+
"""
|
28 |
+
parsed_json = json.loads(initial_json_prompt)
|
29 |
+
self.segment_lookup: Dict[str, JsonPromptSegment] = dict()
|
30 |
+
|
31 |
+
for prompt_entry in parsed_json:
|
32 |
+
segment_index = prompt_entry["segment_index"]
|
33 |
+
prompt = prompt_entry["prompt"]
|
34 |
+
format_prompt = prompt_entry.get("format_prompt", False)
|
35 |
+
self.segment_lookup[str(segment_index)] = JsonPromptSegment(segment_index, prompt, format_prompt)
|
36 |
+
|
37 |
+
def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
|
38 |
+
# Lookup prompt
|
39 |
+
prompt = self.segment_lookup.get(str(segment_index), None)
|
40 |
+
|
41 |
+
if (prompt is None):
|
42 |
+
# No prompt found, return whisper prompt
|
43 |
+
print(f"Could not find prompt for segment {segment_index}, returning whisper prompt")
|
44 |
+
return whisper_prompt
|
45 |
+
|
46 |
+
if (prompt.format_prompt):
|
47 |
+
return prompt.prompt.format(whisper_prompt)
|
48 |
+
else:
|
49 |
+
return self._concat_prompt(prompt.prompt, whisper_prompt)
|
src/prompts/prependPromptStrategy.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from src.config import VadInitialPromptMode
|
2 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
3 |
+
|
4 |
+
class PrependPromptStrategy(AbstractPromptStrategy):
|
5 |
+
"""
|
6 |
+
A simple prompt strategy that prepends a single prompt to all segments of audio, or prepends the prompt to the first segment of audio.
|
7 |
+
"""
|
8 |
+
def __init__(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode):
|
9 |
+
"""
|
10 |
+
Parameters
|
11 |
+
----------
|
12 |
+
initial_prompt: str
|
13 |
+
The initial prompt to use for the transcription.
|
14 |
+
initial_prompt_mode: VadInitialPromptMode
|
15 |
+
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
16 |
+
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
17 |
+
"""
|
18 |
+
self.initial_prompt = initial_prompt
|
19 |
+
self.initial_prompt_mode = initial_prompt_mode
|
20 |
+
|
21 |
+
# This is a simple prompt strategy, so we only support these two modes
|
22 |
+
if initial_prompt_mode not in [VadInitialPromptMode.PREPEND_ALL_SEGMENTS, VadInitialPromptMode.PREPREND_FIRST_SEGMENT]:
|
23 |
+
raise ValueError(f"Unsupported initial prompt mode {initial_prompt_mode}")
|
24 |
+
|
25 |
+
def get_segment_prompt(self, segment_index: int, whisper_prompt: str, detected_language: str) -> str:
|
26 |
+
if (self.initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
|
27 |
+
return self._concat_prompt(self.initial_prompt, whisper_prompt)
|
28 |
+
elif (self.initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
29 |
+
return self._concat_prompt(self.initial_prompt, whisper_prompt) if segment_index == 0 else whisper_prompt
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Unknown initial prompt mode {self.initial_prompt_mode}")
|
src/segments.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List
|
2 |
+
|
3 |
+
import copy
|
4 |
+
|
5 |
+
def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
|
6 |
+
result = []
|
7 |
+
|
8 |
+
if len(timestamps) == 0:
|
9 |
+
return result
|
10 |
+
if max_merge_size is None:
|
11 |
+
return timestamps
|
12 |
+
|
13 |
+
if padding_left is None:
|
14 |
+
padding_left = 0
|
15 |
+
if padding_right is None:
|
16 |
+
padding_right = 0
|
17 |
+
|
18 |
+
processed_time = 0
|
19 |
+
current_segment = None
|
20 |
+
|
21 |
+
for i in range(len(timestamps)):
|
22 |
+
next_segment = timestamps[i]
|
23 |
+
|
24 |
+
delta = next_segment['start'] - processed_time
|
25 |
+
|
26 |
+
# Note that segments can still be longer than the max merge size, they just won't be merged in that case
|
27 |
+
if current_segment is None or (merge_window is not None and delta > merge_window) \
|
28 |
+
or next_segment['end'] - current_segment['start'] > max_merge_size:
|
29 |
+
# Finish the current segment
|
30 |
+
if current_segment is not None:
|
31 |
+
# Add right padding
|
32 |
+
finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
|
33 |
+
current_segment['end'] += finish_padding
|
34 |
+
delta -= finish_padding
|
35 |
+
|
36 |
+
result.append(current_segment)
|
37 |
+
|
38 |
+
# Start a new segment
|
39 |
+
current_segment = copy.deepcopy(next_segment)
|
40 |
+
|
41 |
+
# Pad the segment
|
42 |
+
current_segment['start'] = current_segment['start'] - min(padding_left, delta)
|
43 |
+
processed_time = current_segment['end']
|
44 |
+
|
45 |
+
else:
|
46 |
+
# Merge the segment
|
47 |
+
current_segment['end'] = next_segment['end']
|
48 |
+
processed_time = current_segment['end']
|
49 |
+
|
50 |
+
# Add the last segment
|
51 |
+
if current_segment is not None:
|
52 |
+
current_segment['end'] += padding_right
|
53 |
+
result.append(current_segment)
|
54 |
+
|
55 |
+
return result
|
src/source.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
|
2 |
+
import os
|
3 |
+
import pathlib
|
4 |
+
from typing import List
|
5 |
+
import zipfile
|
6 |
+
|
7 |
+
import ffmpeg
|
8 |
+
from more_itertools import unzip
|
9 |
+
|
10 |
+
from src.download import ExceededMaximumDuration, download_url
|
11 |
+
|
12 |
+
MAX_FILE_PREFIX_LENGTH = 17
|
13 |
+
|
14 |
+
class AudioSource:
|
15 |
+
def __init__(self, source_path, source_name = None, audio_duration = None):
|
16 |
+
self.source_path = source_path
|
17 |
+
self.source_name = source_name
|
18 |
+
self._audio_duration = audio_duration
|
19 |
+
|
20 |
+
# Load source name if not provided
|
21 |
+
if (self.source_name is None):
|
22 |
+
file_path = pathlib.Path(self.source_path)
|
23 |
+
self.source_name = file_path.name
|
24 |
+
|
25 |
+
def get_audio_duration(self):
|
26 |
+
if self._audio_duration is None:
|
27 |
+
self._audio_duration = float(ffmpeg.probe(self.source_path)["format"]["duration"])
|
28 |
+
|
29 |
+
return self._audio_duration
|
30 |
+
|
31 |
+
def get_full_name(self):
|
32 |
+
return self.source_name
|
33 |
+
|
34 |
+
def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
|
35 |
+
file_path = pathlib.Path(self.source_name)
|
36 |
+
short_name = file_path.stem[:max_length] + file_path.suffix
|
37 |
+
|
38 |
+
return short_name
|
39 |
+
|
40 |
+
def __str__(self) -> str:
|
41 |
+
return self.source_path
|
42 |
+
|
43 |
+
class AudioSourceCollection:
|
44 |
+
def __init__(self, sources: List[AudioSource]):
|
45 |
+
self.sources = sources
|
46 |
+
|
47 |
+
def __iter__(self):
|
48 |
+
return iter(self.sources)
|
49 |
+
|
50 |
+
def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
|
51 |
+
output: List[AudioSource] = []
|
52 |
+
|
53 |
+
if urlData:
|
54 |
+
# Download from YouTube. This could also be a playlist or a channel.
|
55 |
+
output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
|
56 |
+
else:
|
57 |
+
# Add input files
|
58 |
+
if (multipleFiles is not None):
|
59 |
+
output.extend([ AudioSource(x.name) for x in multipleFiles ])
|
60 |
+
if (microphoneData is not None):
|
61 |
+
output.append(AudioSource(microphoneData))
|
62 |
+
|
63 |
+
total_duration = 0
|
64 |
+
|
65 |
+
# Calculate total audio length. We do this even if input_audio_max_duration
|
66 |
+
# is disabled to ensure that all the audio files are valid.
|
67 |
+
for source in output:
|
68 |
+
audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
|
69 |
+
total_duration += float(audioDuration)
|
70 |
+
|
71 |
+
# Save audio duration
|
72 |
+
source._audio_duration = float(audioDuration)
|
73 |
+
|
74 |
+
# Ensure the total duration of the audio is not too long
|
75 |
+
if input_audio_max_duration > 0:
|
76 |
+
if float(total_duration) > input_audio_max_duration:
|
77 |
+
raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
|
78 |
+
|
79 |
+
# Return a list of audio sources
|
80 |
+
return output
|
src/utils.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import textwrap
|
2 |
+
import unicodedata
|
3 |
+
import re
|
4 |
+
|
5 |
+
import zlib
|
6 |
+
from typing import Iterator, TextIO, Union
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
import urllib3
|
10 |
+
|
11 |
+
|
12 |
+
def exact_div(x, y):
|
13 |
+
assert x % y == 0
|
14 |
+
return x // y
|
15 |
+
|
16 |
+
|
17 |
+
def str2bool(string):
|
18 |
+
str2val = {"True": True, "False": False}
|
19 |
+
if string in str2val:
|
20 |
+
return str2val[string]
|
21 |
+
else:
|
22 |
+
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
23 |
+
|
24 |
+
|
25 |
+
def optional_int(string):
|
26 |
+
return None if string == "None" else int(string)
|
27 |
+
|
28 |
+
|
29 |
+
def optional_float(string):
|
30 |
+
return None if string == "None" else float(string)
|
31 |
+
|
32 |
+
|
33 |
+
def compression_ratio(text) -> float:
|
34 |
+
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
35 |
+
|
36 |
+
|
37 |
+
def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
|
38 |
+
assert seconds >= 0, "non-negative timestamp expected"
|
39 |
+
milliseconds = round(seconds * 1000.0)
|
40 |
+
|
41 |
+
hours = milliseconds // 3_600_000
|
42 |
+
milliseconds -= hours * 3_600_000
|
43 |
+
|
44 |
+
minutes = milliseconds // 60_000
|
45 |
+
milliseconds -= minutes * 60_000
|
46 |
+
|
47 |
+
seconds = milliseconds // 1_000
|
48 |
+
milliseconds -= seconds * 1_000
|
49 |
+
|
50 |
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
51 |
+
return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
|
52 |
+
|
53 |
+
|
54 |
+
def write_txt(transcript: Iterator[dict], file: TextIO):
|
55 |
+
for segment in transcript:
|
56 |
+
print(segment['text'].strip(), file=file, flush=True)
|
57 |
+
|
58 |
+
|
59 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
60 |
+
maxLineWidth=None, highlight_words: bool = False):
|
61 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
62 |
+
|
63 |
+
print("WEBVTT\n", file=file)
|
64 |
+
|
65 |
+
for segment in iterator:
|
66 |
+
text = segment['text'].replace('-->', '->')
|
67 |
+
|
68 |
+
print(
|
69 |
+
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
70 |
+
f"{text}\n",
|
71 |
+
file=file,
|
72 |
+
flush=True,
|
73 |
+
)
|
74 |
+
|
75 |
+
def write_srt(transcript: Iterator[dict], file: TextIO,
|
76 |
+
maxLineWidth=None, highlight_words: bool = False):
|
77 |
+
"""
|
78 |
+
Write a transcript to a file in SRT format.
|
79 |
+
Example usage:
|
80 |
+
from pathlib import Path
|
81 |
+
from whisper.utils import write_srt
|
82 |
+
result = transcribe(model, audio_path, temperature=temperature, **args)
|
83 |
+
# save SRT
|
84 |
+
audio_basename = Path(audio_path).stem
|
85 |
+
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
86 |
+
write_srt(result["segments"], file=srt)
|
87 |
+
"""
|
88 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
89 |
+
|
90 |
+
for i, segment in enumerate(iterator, start=1):
|
91 |
+
text = segment['text'].replace('-->', '->')
|
92 |
+
|
93 |
+
# write srt lines
|
94 |
+
print(
|
95 |
+
f"{i}\n"
|
96 |
+
f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
|
97 |
+
f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
|
98 |
+
f"{text}\n",
|
99 |
+
file=file,
|
100 |
+
flush=True,
|
101 |
+
)
|
102 |
+
|
103 |
+
def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
|
104 |
+
for segment in transcript:
|
105 |
+
words = segment.get('words', [])
|
106 |
+
|
107 |
+
if len(words) == 0:
|
108 |
+
# Yield the segment as-is or processed
|
109 |
+
if maxLineWidth is None or maxLineWidth < 0:
|
110 |
+
yield segment
|
111 |
+
else:
|
112 |
+
yield {
|
113 |
+
'start': segment['start'],
|
114 |
+
'end': segment['end'],
|
115 |
+
'text': process_text(segment['text'].strip(), maxLineWidth)
|
116 |
+
}
|
117 |
+
# We are done
|
118 |
+
continue
|
119 |
+
|
120 |
+
subtitle_start = segment['start']
|
121 |
+
subtitle_end = segment['end']
|
122 |
+
|
123 |
+
text_words = [ this_word["word"] for this_word in words ]
|
124 |
+
subtitle_text = __join_words(text_words, maxLineWidth)
|
125 |
+
|
126 |
+
# Iterate over the words in the segment
|
127 |
+
if highlight_words:
|
128 |
+
last = subtitle_start
|
129 |
+
|
130 |
+
for i, this_word in enumerate(words):
|
131 |
+
start = this_word['start']
|
132 |
+
end = this_word['end']
|
133 |
+
|
134 |
+
if last != start:
|
135 |
+
# Display the text up to this point
|
136 |
+
yield {
|
137 |
+
'start': last,
|
138 |
+
'end': start,
|
139 |
+
'text': subtitle_text
|
140 |
+
}
|
141 |
+
|
142 |
+
# Display the text with the current word highlighted
|
143 |
+
yield {
|
144 |
+
'start': start,
|
145 |
+
'end': end,
|
146 |
+
'text': __join_words(
|
147 |
+
[
|
148 |
+
{
|
149 |
+
"word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
150 |
+
if j == i
|
151 |
+
else word,
|
152 |
+
# The HTML tags <u> and </u> are not displayed,
|
153 |
+
# # so they should not be counted in the word length
|
154 |
+
"length": len(word)
|
155 |
+
} for j, word in enumerate(text_words)
|
156 |
+
], maxLineWidth)
|
157 |
+
}
|
158 |
+
last = end
|
159 |
+
|
160 |
+
if last != subtitle_end:
|
161 |
+
# Display the last part of the text
|
162 |
+
yield {
|
163 |
+
'start': last,
|
164 |
+
'end': subtitle_end,
|
165 |
+
'text': subtitle_text
|
166 |
+
}
|
167 |
+
|
168 |
+
# Just return the subtitle text
|
169 |
+
else:
|
170 |
+
yield {
|
171 |
+
'start': subtitle_start,
|
172 |
+
'end': subtitle_end,
|
173 |
+
'text': subtitle_text
|
174 |
+
}
|
175 |
+
|
176 |
+
def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
|
177 |
+
if maxLineWidth is None or maxLineWidth < 0:
|
178 |
+
return " ".join(words)
|
179 |
+
|
180 |
+
lines = []
|
181 |
+
current_line = ""
|
182 |
+
current_length = 0
|
183 |
+
|
184 |
+
for entry in words:
|
185 |
+
# Either accept a string or a dict with a 'word' and 'length' field
|
186 |
+
if isinstance(entry, dict):
|
187 |
+
word = entry['word']
|
188 |
+
word_length = entry['length']
|
189 |
+
else:
|
190 |
+
word = entry
|
191 |
+
word_length = len(word)
|
192 |
+
|
193 |
+
if current_length > 0 and current_length + word_length > maxLineWidth:
|
194 |
+
lines.append(current_line)
|
195 |
+
current_line = ""
|
196 |
+
current_length = 0
|
197 |
+
|
198 |
+
current_length += word_length
|
199 |
+
# The word will be prefixed with a space by Whisper, so we don't need to add one here
|
200 |
+
current_line += word
|
201 |
+
|
202 |
+
if len(current_line) > 0:
|
203 |
+
lines.append(current_line)
|
204 |
+
|
205 |
+
return "\n".join(lines)
|
206 |
+
|
207 |
+
def process_text(text: str, maxLineWidth=None):
|
208 |
+
if (maxLineWidth is None or maxLineWidth < 0):
|
209 |
+
return text
|
210 |
+
|
211 |
+
lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
|
212 |
+
return '\n'.join(lines)
|
213 |
+
|
214 |
+
def slugify(value, allow_unicode=False):
|
215 |
+
"""
|
216 |
+
Taken from https://github.com/django/django/blob/master/django/utils/text.py
|
217 |
+
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
|
218 |
+
dashes to single dashes. Remove characters that aren't alphanumerics,
|
219 |
+
underscores, or hyphens. Convert to lowercase. Also strip leading and
|
220 |
+
trailing whitespace, dashes, and underscores.
|
221 |
+
"""
|
222 |
+
value = str(value)
|
223 |
+
if allow_unicode:
|
224 |
+
value = unicodedata.normalize('NFKC', value)
|
225 |
+
else:
|
226 |
+
value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
|
227 |
+
value = re.sub(r'[^\w\s-]', '', value.lower())
|
228 |
+
return re.sub(r'[-\s]+', '-', value).strip('-_')
|
229 |
+
|
230 |
+
def download_file(url: str, destination: str):
|
231 |
+
with urllib3.request.urlopen(url) as source, open(destination, "wb") as output:
|
232 |
+
with tqdm(
|
233 |
+
total=int(source.info().get("Content-Length")),
|
234 |
+
ncols=80,
|
235 |
+
unit="iB",
|
236 |
+
unit_scale=True,
|
237 |
+
unit_divisor=1024,
|
238 |
+
) as loop:
|
239 |
+
while True:
|
240 |
+
buffer = source.read(8192)
|
241 |
+
if not buffer:
|
242 |
+
break
|
243 |
+
|
244 |
+
output.write(buffer)
|
245 |
+
loop.update(len(buffer))
|
src/vad.py
ADDED
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from collections import Counter, deque
|
3 |
+
import time
|
4 |
+
|
5 |
+
from typing import Any, Deque, Iterator, List, Dict
|
6 |
+
|
7 |
+
from pprint import pprint
|
8 |
+
from src.hooks.progressListener import ProgressListener
|
9 |
+
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
10 |
+
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
11 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
12 |
+
|
13 |
+
from src.segments import merge_timestamps
|
14 |
+
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
|
15 |
+
|
16 |
+
# Workaround for https://github.com/tensorflow/tensorflow/issues/48797
|
17 |
+
try:
|
18 |
+
import tensorflow as tf
|
19 |
+
except ModuleNotFoundError:
|
20 |
+
# Error handling
|
21 |
+
pass
|
22 |
+
|
23 |
+
import torch
|
24 |
+
|
25 |
+
import ffmpeg
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
from src.utils import format_timestamp
|
29 |
+
from enum import Enum
|
30 |
+
|
31 |
+
class NonSpeechStrategy(Enum):
|
32 |
+
"""
|
33 |
+
Ignore non-speech frames segments.
|
34 |
+
"""
|
35 |
+
SKIP = 1
|
36 |
+
"""
|
37 |
+
Just treat non-speech segments as speech.
|
38 |
+
"""
|
39 |
+
CREATE_SEGMENT = 2
|
40 |
+
"""
|
41 |
+
Expand speech segments into subsequent non-speech segments.
|
42 |
+
"""
|
43 |
+
EXPAND_SEGMENT = 3
|
44 |
+
|
45 |
+
# Defaults for Silero
|
46 |
+
SPEECH_TRESHOLD = 0.3
|
47 |
+
|
48 |
+
# Minimum size of segments to process
|
49 |
+
MIN_SEGMENT_DURATION = 1
|
50 |
+
|
51 |
+
# The maximum time for texts from old segments to be used in the next segment
|
52 |
+
MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
|
53 |
+
PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
|
54 |
+
|
55 |
+
VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
|
56 |
+
|
57 |
+
class TranscriptionConfig(ABC):
|
58 |
+
def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
59 |
+
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
60 |
+
max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
|
61 |
+
self.non_speech_strategy = non_speech_strategy
|
62 |
+
self.segment_padding_left = segment_padding_left
|
63 |
+
self.segment_padding_right = segment_padding_right
|
64 |
+
self.max_silent_period = max_silent_period
|
65 |
+
self.max_merge_size = max_merge_size
|
66 |
+
self.max_prompt_window = max_prompt_window
|
67 |
+
self.initial_segment_index = initial_segment_index
|
68 |
+
|
69 |
+
class PeriodicTranscriptionConfig(TranscriptionConfig):
|
70 |
+
def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
|
71 |
+
segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
|
72 |
+
max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
|
73 |
+
super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
|
74 |
+
self.periodic_duration = periodic_duration
|
75 |
+
|
76 |
+
class AbstractTranscription(ABC):
|
77 |
+
def __init__(self, sampling_rate: int = 16000):
|
78 |
+
self.sampling_rate = sampling_rate
|
79 |
+
|
80 |
+
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
81 |
+
return load_audio(str, self.sampling_rate, start_time, duration)
|
82 |
+
|
83 |
+
def is_transcribe_timestamps_fast(self):
|
84 |
+
"""
|
85 |
+
Determine if get_transcribe_timestamps is fast enough to not need parallelization.
|
86 |
+
"""
|
87 |
+
return False
|
88 |
+
|
89 |
+
@abstractmethod
|
90 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
|
91 |
+
"""
|
92 |
+
Get the start and end timestamps of the sections that should be transcribed by this VAD method.
|
93 |
+
|
94 |
+
Parameters
|
95 |
+
----------
|
96 |
+
audio: str
|
97 |
+
The audio file.
|
98 |
+
config: TranscriptionConfig
|
99 |
+
The transcription configuration.
|
100 |
+
|
101 |
+
Returns
|
102 |
+
-------
|
103 |
+
A list of start and end timestamps, in fractional seconds.
|
104 |
+
"""
|
105 |
+
return
|
106 |
+
|
107 |
+
def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
|
108 |
+
"""
|
109 |
+
Get the start and end timestamps of the sections that should be transcribed by this VAD method,
|
110 |
+
after merging the given segments using the specified configuration.
|
111 |
+
|
112 |
+
Parameters
|
113 |
+
----------
|
114 |
+
audio: str
|
115 |
+
The audio file.
|
116 |
+
config: TranscriptionConfig
|
117 |
+
The transcription configuration.
|
118 |
+
|
119 |
+
Returns
|
120 |
+
-------
|
121 |
+
A list of start and end timestamps, in fractional seconds.
|
122 |
+
"""
|
123 |
+
merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
|
124 |
+
config.segment_padding_left, config.segment_padding_right)
|
125 |
+
|
126 |
+
if config.non_speech_strategy != NonSpeechStrategy.SKIP:
|
127 |
+
# Expand segments to include the gaps between them
|
128 |
+
if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
|
129 |
+
# When we have a prompt window, we create speech segments betwen each segment if we exceed the merge size
|
130 |
+
merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
|
131 |
+
elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
|
132 |
+
# With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
|
133 |
+
merged = self.expand_gaps(merged, total_duration=total_duration)
|
134 |
+
else:
|
135 |
+
raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
|
136 |
+
|
137 |
+
print("Transcribing non-speech:")
|
138 |
+
pprint(merged)
|
139 |
+
return merged
|
140 |
+
|
141 |
+
def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
|
142 |
+
progressListener: ProgressListener = None):
|
143 |
+
"""
|
144 |
+
Transcribe the given audo file.
|
145 |
+
|
146 |
+
Parameters
|
147 |
+
----------
|
148 |
+
audio: str
|
149 |
+
The audio file.
|
150 |
+
whisperCallable: WhisperCallback
|
151 |
+
A callback object to call to transcribe each segment.
|
152 |
+
|
153 |
+
Returns
|
154 |
+
-------
|
155 |
+
A list of start and end timestamps, in fractional seconds.
|
156 |
+
"""
|
157 |
+
|
158 |
+
try:
|
159 |
+
max_audio_duration = self.get_audio_duration(audio, config)
|
160 |
+
timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
|
161 |
+
|
162 |
+
# Get speech timestamps from full audio file
|
163 |
+
merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
|
164 |
+
|
165 |
+
# A deque of transcribed segments that is passed to the next segment as a prompt
|
166 |
+
prompt_window = deque()
|
167 |
+
|
168 |
+
print("Processing timestamps:")
|
169 |
+
pprint(merged)
|
170 |
+
|
171 |
+
result = {
|
172 |
+
'text': "",
|
173 |
+
'segments': [],
|
174 |
+
'language': ""
|
175 |
+
}
|
176 |
+
languageCounter = Counter()
|
177 |
+
detected_language = None
|
178 |
+
|
179 |
+
segment_index = config.initial_segment_index
|
180 |
+
|
181 |
+
# Calculate progress
|
182 |
+
progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
|
183 |
+
progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
|
184 |
+
|
185 |
+
# For each time segment, run whisper
|
186 |
+
for segment in merged:
|
187 |
+
segment_index += 1
|
188 |
+
segment_start = segment['start']
|
189 |
+
segment_end = segment['end']
|
190 |
+
segment_expand_amount = segment.get('expand_amount', 0)
|
191 |
+
segment_gap = segment.get('gap', False)
|
192 |
+
|
193 |
+
segment_duration = segment_end - segment_start
|
194 |
+
|
195 |
+
if segment_duration < MIN_SEGMENT_DURATION:
|
196 |
+
continue
|
197 |
+
|
198 |
+
# Audio to run on Whisper
|
199 |
+
segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
|
200 |
+
# Previous segments to use as a prompt
|
201 |
+
segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
|
202 |
+
|
203 |
+
# Detected language
|
204 |
+
detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
|
205 |
+
|
206 |
+
print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
|
207 |
+
segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
|
208 |
+
|
209 |
+
perf_start_time = time.perf_counter()
|
210 |
+
|
211 |
+
scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
|
212 |
+
sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
|
213 |
+
segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
|
214 |
+
|
215 |
+
perf_end_time = time.perf_counter()
|
216 |
+
print("Whisper took {} seconds".format(perf_end_time - perf_start_time))
|
217 |
+
|
218 |
+
adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
|
219 |
+
|
220 |
+
# Propagate expand amount to the segments
|
221 |
+
if (segment_expand_amount > 0):
|
222 |
+
segment_without_expansion = segment_duration - segment_expand_amount
|
223 |
+
|
224 |
+
for adjusted_segment in adjusted_segments:
|
225 |
+
adjusted_segment_end = adjusted_segment['end']
|
226 |
+
|
227 |
+
# Add expand amount if the segment got expanded
|
228 |
+
if (adjusted_segment_end > segment_without_expansion):
|
229 |
+
adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
|
230 |
+
|
231 |
+
# Append to output
|
232 |
+
result['text'] += segment_result['text']
|
233 |
+
result['segments'].extend(adjusted_segments)
|
234 |
+
|
235 |
+
# Increment detected language
|
236 |
+
if not segment_gap:
|
237 |
+
languageCounter[segment_result['language']] += 1
|
238 |
+
|
239 |
+
# Update prompt window
|
240 |
+
self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
|
241 |
+
|
242 |
+
if detected_language is not None:
|
243 |
+
result['language'] = detected_language
|
244 |
+
finally:
|
245 |
+
# Notify progress listener that we are done
|
246 |
+
if progressListener is not None:
|
247 |
+
progressListener.on_finished()
|
248 |
+
return result
|
249 |
+
|
250 |
+
def get_audio_duration(self, audio: str, config: TranscriptionConfig):
|
251 |
+
return get_audio_duration(audio)
|
252 |
+
|
253 |
+
def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
|
254 |
+
if (config.max_prompt_window is not None and config.max_prompt_window > 0):
|
255 |
+
# Add segments to the current prompt window (unless it is a speech gap)
|
256 |
+
if not segment_gap:
|
257 |
+
for segment in adjusted_segments:
|
258 |
+
if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
|
259 |
+
prompt_window.append(segment)
|
260 |
+
|
261 |
+
while (len(prompt_window) > 0):
|
262 |
+
first_end_time = prompt_window[0].get('end', 0)
|
263 |
+
# Time expanded in the segments should be discounted from the prompt window
|
264 |
+
first_expand_time = prompt_window[0].get('expand_amount', 0)
|
265 |
+
|
266 |
+
if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
|
267 |
+
prompt_window.popleft()
|
268 |
+
else:
|
269 |
+
break
|
270 |
+
|
271 |
+
def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
|
272 |
+
result = []
|
273 |
+
last_end_time = 0
|
274 |
+
|
275 |
+
for segment in segments:
|
276 |
+
segment_start = float(segment['start'])
|
277 |
+
segment_end = float(segment['end'])
|
278 |
+
|
279 |
+
if (last_end_time != segment_start):
|
280 |
+
delta = segment_start - last_end_time
|
281 |
+
|
282 |
+
if (min_gap_length is None or delta >= min_gap_length):
|
283 |
+
result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
|
284 |
+
|
285 |
+
last_end_time = segment_end
|
286 |
+
result.append(segment)
|
287 |
+
|
288 |
+
# Also include total duration if specified
|
289 |
+
if (total_duration is not None and last_end_time < total_duration):
|
290 |
+
delta = total_duration - segment_start
|
291 |
+
|
292 |
+
if (min_gap_length is None or delta >= min_gap_length):
|
293 |
+
result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
|
294 |
+
|
295 |
+
return result
|
296 |
+
|
297 |
+
# Expand the end time of each segment to the start of the next segment
|
298 |
+
def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
|
299 |
+
result = []
|
300 |
+
|
301 |
+
if len(segments) == 0:
|
302 |
+
return result
|
303 |
+
|
304 |
+
# Add gap at the beginning if needed
|
305 |
+
if (segments[0]['start'] > 0):
|
306 |
+
result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
|
307 |
+
|
308 |
+
for i in range(len(segments) - 1):
|
309 |
+
current_segment = segments[i]
|
310 |
+
next_segment = segments[i + 1]
|
311 |
+
|
312 |
+
delta = next_segment['start'] - current_segment['end']
|
313 |
+
|
314 |
+
# Expand if the gap actually exists
|
315 |
+
if (delta >= 0):
|
316 |
+
current_segment = current_segment.copy()
|
317 |
+
current_segment['expand_amount'] = delta
|
318 |
+
current_segment['end'] = next_segment['start']
|
319 |
+
|
320 |
+
result.append(current_segment)
|
321 |
+
|
322 |
+
# Add last segment
|
323 |
+
last_segment = segments[-1]
|
324 |
+
result.append(last_segment)
|
325 |
+
|
326 |
+
# Also include total duration if specified
|
327 |
+
if (total_duration is not None):
|
328 |
+
last_segment = result[-1]
|
329 |
+
|
330 |
+
if (last_segment['end'] < total_duration):
|
331 |
+
last_segment = last_segment.copy()
|
332 |
+
last_segment['end'] = total_duration
|
333 |
+
result[-1] = last_segment
|
334 |
+
|
335 |
+
return result
|
336 |
+
|
337 |
+
def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
|
338 |
+
result = []
|
339 |
+
|
340 |
+
if len(segments) == 0:
|
341 |
+
return result
|
342 |
+
|
343 |
+
# Add gap at the beginning if needed
|
344 |
+
if (segments[0]['start'] > 0):
|
345 |
+
result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
|
346 |
+
|
347 |
+
for i in range(len(segments) - 1):
|
348 |
+
expanded = False
|
349 |
+
current_segment = segments[i]
|
350 |
+
next_segment = segments[i + 1]
|
351 |
+
|
352 |
+
delta = next_segment['start'] - current_segment['end']
|
353 |
+
|
354 |
+
if (max_expand_size is not None and delta <= max_expand_size):
|
355 |
+
# Just expand the current segment
|
356 |
+
current_segment = current_segment.copy()
|
357 |
+
current_segment['expand_amount'] = delta
|
358 |
+
current_segment['end'] = next_segment['start']
|
359 |
+
expanded = True
|
360 |
+
|
361 |
+
result.append(current_segment)
|
362 |
+
|
363 |
+
# Add a gap to the next segment if needed
|
364 |
+
if (delta >= 0 and not expanded):
|
365 |
+
result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
|
366 |
+
|
367 |
+
# Add last segment
|
368 |
+
last_segment = segments[-1]
|
369 |
+
result.append(last_segment)
|
370 |
+
|
371 |
+
# Also include total duration if specified
|
372 |
+
if (total_duration is not None):
|
373 |
+
last_segment = result[-1]
|
374 |
+
|
375 |
+
delta = total_duration - last_segment['end']
|
376 |
+
|
377 |
+
if (delta > 0):
|
378 |
+
if (max_expand_size is not None and delta <= max_expand_size):
|
379 |
+
# Expand the last segment
|
380 |
+
last_segment = last_segment.copy()
|
381 |
+
last_segment['expand_amount'] = delta
|
382 |
+
last_segment['end'] = total_duration
|
383 |
+
result[-1] = last_segment
|
384 |
+
else:
|
385 |
+
result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
|
386 |
+
|
387 |
+
return result
|
388 |
+
|
389 |
+
def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
|
390 |
+
result = []
|
391 |
+
|
392 |
+
for segment in segments:
|
393 |
+
segment_start = float(segment['start'])
|
394 |
+
segment_end = float(segment['end'])
|
395 |
+
|
396 |
+
# Filter segments?
|
397 |
+
if (max_source_time is not None):
|
398 |
+
if (segment_start > max_source_time):
|
399 |
+
continue
|
400 |
+
segment_end = min(max_source_time, segment_end)
|
401 |
+
|
402 |
+
new_segment = segment.copy()
|
403 |
+
|
404 |
+
# Add to start and end
|
405 |
+
new_segment['start'] = segment_start + adjust_seconds
|
406 |
+
new_segment['end'] = segment_end + adjust_seconds
|
407 |
+
|
408 |
+
# Handle words
|
409 |
+
if ('words' in new_segment):
|
410 |
+
for word in new_segment['words']:
|
411 |
+
# Adjust start and end
|
412 |
+
word['start'] = word['start'] + adjust_seconds
|
413 |
+
word['end'] = word['end'] + adjust_seconds
|
414 |
+
|
415 |
+
result.append(new_segment)
|
416 |
+
return result
|
417 |
+
|
418 |
+
def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
|
419 |
+
result = []
|
420 |
+
|
421 |
+
for entry in timestamps:
|
422 |
+
start = entry['start']
|
423 |
+
end = entry['end']
|
424 |
+
|
425 |
+
result.append({
|
426 |
+
'start': start * factor,
|
427 |
+
'end': end * factor
|
428 |
+
})
|
429 |
+
return result
|
430 |
+
|
431 |
+
|
432 |
+
class VadSileroTranscription(AbstractTranscription):
|
433 |
+
def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
|
434 |
+
super().__init__(sampling_rate=sampling_rate)
|
435 |
+
self.model = None
|
436 |
+
self.cache = cache
|
437 |
+
self._initialize_model()
|
438 |
+
|
439 |
+
def _initialize_model(self):
|
440 |
+
if (self.cache is not None):
|
441 |
+
model_key = "VadSileroTranscription"
|
442 |
+
self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
|
443 |
+
print("Loaded Silerio model from cache.")
|
444 |
+
else:
|
445 |
+
self.model, self.get_speech_timestamps = self._create_model()
|
446 |
+
print("Created Silerio model")
|
447 |
+
|
448 |
+
def _create_model(self):
|
449 |
+
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
|
450 |
+
|
451 |
+
# Silero does not benefit from multi-threading
|
452 |
+
torch.set_num_threads(1) # JIT
|
453 |
+
(get_speech_timestamps, _, _, _, _) = utils
|
454 |
+
|
455 |
+
return model, get_speech_timestamps
|
456 |
+
|
457 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
|
458 |
+
result = []
|
459 |
+
|
460 |
+
print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
|
461 |
+
perf_start_time = time.perf_counter()
|
462 |
+
|
463 |
+
# Divide procesisng of audio into chunks
|
464 |
+
chunk_start = start_time
|
465 |
+
|
466 |
+
while (chunk_start < end_time):
|
467 |
+
chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
|
468 |
+
|
469 |
+
print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
|
470 |
+
wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
|
471 |
+
|
472 |
+
sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
|
473 |
+
seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
|
474 |
+
adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
|
475 |
+
|
476 |
+
#pprint(adjusted)
|
477 |
+
|
478 |
+
result.extend(adjusted)
|
479 |
+
chunk_start += chunk_duration
|
480 |
+
|
481 |
+
perf_end_time = time.perf_counter()
|
482 |
+
print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
|
483 |
+
|
484 |
+
return result
|
485 |
+
|
486 |
+
def __getstate__(self):
|
487 |
+
# We only need the sampling rate
|
488 |
+
return { 'sampling_rate': self.sampling_rate }
|
489 |
+
|
490 |
+
def __setstate__(self, state):
|
491 |
+
self.sampling_rate = state['sampling_rate']
|
492 |
+
self.model = None
|
493 |
+
# Use the global cache
|
494 |
+
self.cache = GLOBAL_MODEL_CACHE
|
495 |
+
self._initialize_model()
|
496 |
+
|
497 |
+
# A very simple VAD that just marks every N seconds as speech
|
498 |
+
class VadPeriodicTranscription(AbstractTranscription):
|
499 |
+
def __init__(self, sampling_rate: int = 16000):
|
500 |
+
super().__init__(sampling_rate=sampling_rate)
|
501 |
+
|
502 |
+
def is_transcribe_timestamps_fast(self):
|
503 |
+
# This is a very fast VAD - no need to parallelize it
|
504 |
+
return True
|
505 |
+
|
506 |
+
def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
|
507 |
+
result = []
|
508 |
+
|
509 |
+
# Generate a timestamp every N seconds
|
510 |
+
start_timestamp = start_time
|
511 |
+
|
512 |
+
while (start_timestamp < end_time):
|
513 |
+
end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
|
514 |
+
segment_duration = end_timestamp - start_timestamp
|
515 |
+
|
516 |
+
# Minimum duration is 1 second
|
517 |
+
if (segment_duration >= 1):
|
518 |
+
result.append( { 'start': start_timestamp, 'end': end_timestamp } )
|
519 |
+
|
520 |
+
start_timestamp = end_timestamp
|
521 |
+
|
522 |
+
return result
|
523 |
+
|
524 |
+
def get_audio_duration(file: str):
|
525 |
+
return float(ffmpeg.probe(file)["format"]["duration"])
|
526 |
+
|
527 |
+
def load_audio(file: str, sample_rate: int = 16000,
|
528 |
+
start_time: str = None, duration: str = None):
|
529 |
+
"""
|
530 |
+
Open an audio file and read as mono waveform, resampling as necessary
|
531 |
+
|
532 |
+
Parameters
|
533 |
+
----------
|
534 |
+
file: str
|
535 |
+
The audio file to open
|
536 |
+
|
537 |
+
sr: int
|
538 |
+
The sample rate to resample the audio if necessary
|
539 |
+
|
540 |
+
start_time: str
|
541 |
+
The start time, using the standard FFMPEG time duration syntax, or None to disable.
|
542 |
+
|
543 |
+
duration: str
|
544 |
+
The duration, using the standard FFMPEG time duration syntax, or None to disable.
|
545 |
+
|
546 |
+
Returns
|
547 |
+
-------
|
548 |
+
A NumPy array containing the audio waveform, in float32 dtype.
|
549 |
+
"""
|
550 |
+
try:
|
551 |
+
inputArgs = {'threads': 0}
|
552 |
+
|
553 |
+
if (start_time is not None):
|
554 |
+
inputArgs['ss'] = start_time
|
555 |
+
if (duration is not None):
|
556 |
+
inputArgs['t'] = duration
|
557 |
+
|
558 |
+
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
559 |
+
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
560 |
+
out, _ = (
|
561 |
+
ffmpeg.input(file, **inputArgs)
|
562 |
+
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
|
563 |
+
.run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
|
564 |
+
)
|
565 |
+
except ffmpeg.Error as e:
|
566 |
+
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
|
567 |
+
|
568 |
+
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
src/vadParallel.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import multiprocessing
|
2 |
+
from queue import Empty
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
from src.hooks.progressListener import ProgressListener
|
6 |
+
from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
|
7 |
+
|
8 |
+
from multiprocessing import Pool, Queue
|
9 |
+
|
10 |
+
from typing import Any, Dict, List, Union
|
11 |
+
import os
|
12 |
+
|
13 |
+
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback
|
14 |
+
|
15 |
+
class _ProgressListenerToQueue(ProgressListener):
|
16 |
+
def __init__(self, progress_queue: Queue):
|
17 |
+
self.progress_queue = progress_queue
|
18 |
+
self.progress_total = 0
|
19 |
+
self.prev_progress = 0
|
20 |
+
|
21 |
+
def on_progress(self, current: Union[int, float], total: Union[int, float]):
|
22 |
+
delta = current - self.prev_progress
|
23 |
+
self.prev_progress = current
|
24 |
+
self.progress_total = total
|
25 |
+
self.progress_queue.put(delta)
|
26 |
+
|
27 |
+
def on_finished(self):
|
28 |
+
if self.progress_total > self.prev_progress:
|
29 |
+
delta = self.progress_total - self.prev_progress
|
30 |
+
self.progress_queue.put(delta)
|
31 |
+
self.prev_progress = self.progress_total
|
32 |
+
|
33 |
+
class ParallelContext:
|
34 |
+
def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
|
35 |
+
self.num_processes = num_processes
|
36 |
+
self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
|
37 |
+
self.lock = threading.Lock()
|
38 |
+
|
39 |
+
self.ref_count = 0
|
40 |
+
self.pool = None
|
41 |
+
self.cleanup_timer = None
|
42 |
+
|
43 |
+
def get_pool(self):
|
44 |
+
# Initialize pool lazily
|
45 |
+
if (self.pool is None):
|
46 |
+
context = multiprocessing.get_context('spawn')
|
47 |
+
self.pool = context.Pool(self.num_processes)
|
48 |
+
|
49 |
+
self.ref_count = self.ref_count + 1
|
50 |
+
|
51 |
+
if (self.auto_cleanup_timeout_seconds is not None):
|
52 |
+
self._stop_auto_cleanup()
|
53 |
+
|
54 |
+
return self.pool
|
55 |
+
|
56 |
+
def return_pool(self, pool):
|
57 |
+
if (self.pool == pool and self.ref_count > 0):
|
58 |
+
self.ref_count = self.ref_count - 1
|
59 |
+
|
60 |
+
if (self.ref_count == 0):
|
61 |
+
if (self.auto_cleanup_timeout_seconds is not None):
|
62 |
+
self._start_auto_cleanup()
|
63 |
+
|
64 |
+
def _start_auto_cleanup(self):
|
65 |
+
if (self.cleanup_timer is not None):
|
66 |
+
self.cleanup_timer.cancel()
|
67 |
+
self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
|
68 |
+
self.cleanup_timer.start()
|
69 |
+
|
70 |
+
print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
|
71 |
+
|
72 |
+
def _stop_auto_cleanup(self):
|
73 |
+
if (self.cleanup_timer is not None):
|
74 |
+
self.cleanup_timer.cancel()
|
75 |
+
self.cleanup_timer = None
|
76 |
+
|
77 |
+
print("Stopped auto cleanup of pool")
|
78 |
+
|
79 |
+
def _execute_cleanup(self):
|
80 |
+
print("Executing cleanup of pool")
|
81 |
+
|
82 |
+
if (self.ref_count == 0):
|
83 |
+
self.close()
|
84 |
+
|
85 |
+
def close(self):
|
86 |
+
self._stop_auto_cleanup()
|
87 |
+
|
88 |
+
if (self.pool is not None):
|
89 |
+
print("Closing pool of " + str(self.num_processes) + " processes")
|
90 |
+
self.pool.close()
|
91 |
+
self.pool.join()
|
92 |
+
self.pool = None
|
93 |
+
|
94 |
+
class ParallelTranscriptionConfig(TranscriptionConfig):
|
95 |
+
def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
|
96 |
+
super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
|
97 |
+
self.device_id = device_id
|
98 |
+
self.override_timestamps = override_timestamps
|
99 |
+
|
100 |
+
class ParallelTranscription(AbstractTranscription):
|
101 |
+
# Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
|
102 |
+
# into smaller segments than 2 minute (min 6 seconds per CPU core)
|
103 |
+
MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
|
104 |
+
|
105 |
+
def __init__(self, sampling_rate: int = 16000):
|
106 |
+
super().__init__(sampling_rate=sampling_rate)
|
107 |
+
|
108 |
+
def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: AbstractWhisperCallback, config: TranscriptionConfig,
|
109 |
+
cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
|
110 |
+
progress_listener: ProgressListener = None):
|
111 |
+
total_duration = get_audio_duration(audio)
|
112 |
+
|
113 |
+
# First, get the timestamps for the original audio
|
114 |
+
if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
|
115 |
+
merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
|
116 |
+
else:
|
117 |
+
timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
|
118 |
+
merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
|
119 |
+
|
120 |
+
# We must make sure the whisper model is downloaded
|
121 |
+
if (len(gpu_devices) > 1):
|
122 |
+
whisperCallable.model_container.ensure_downloaded()
|
123 |
+
|
124 |
+
# Split into a list for each device
|
125 |
+
# TODO: Split by time instead of by number of chunks
|
126 |
+
merged_split = list(self._split(merged, len(gpu_devices)))
|
127 |
+
|
128 |
+
# Parameters that will be passed to the transcribe function
|
129 |
+
parameters = []
|
130 |
+
segment_index = config.initial_segment_index
|
131 |
+
|
132 |
+
processing_manager = multiprocessing.Manager()
|
133 |
+
progress_queue = processing_manager.Queue()
|
134 |
+
|
135 |
+
for i in range(len(gpu_devices)):
|
136 |
+
# Note that device_segment_list can be empty. But we will still create a process for it,
|
137 |
+
# as otherwise we run the risk of assigning the same device to multiple processes.
|
138 |
+
device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
|
139 |
+
device_id = gpu_devices[i]
|
140 |
+
|
141 |
+
print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
|
142 |
+
|
143 |
+
# Create a new config with the given device ID
|
144 |
+
device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
|
145 |
+
segment_index += len(device_segment_list)
|
146 |
+
|
147 |
+
progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
|
148 |
+
parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);
|
149 |
+
|
150 |
+
merged = {
|
151 |
+
'text': '',
|
152 |
+
'segments': [],
|
153 |
+
'language': None
|
154 |
+
}
|
155 |
+
|
156 |
+
created_context = False
|
157 |
+
|
158 |
+
perf_start_gpu = time.perf_counter()
|
159 |
+
|
160 |
+
# Spawn a separate process for each device
|
161 |
+
try:
|
162 |
+
if (gpu_parallel_context is None):
|
163 |
+
gpu_parallel_context = ParallelContext(len(gpu_devices))
|
164 |
+
created_context = True
|
165 |
+
|
166 |
+
# Get a pool of processes
|
167 |
+
pool = gpu_parallel_context.get_pool()
|
168 |
+
|
169 |
+
# Run the transcription in parallel
|
170 |
+
results_async = pool.starmap_async(self.transcribe, parameters)
|
171 |
+
total_progress = 0
|
172 |
+
|
173 |
+
while not results_async.ready():
|
174 |
+
try:
|
175 |
+
delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
|
176 |
+
except Empty:
|
177 |
+
continue
|
178 |
+
|
179 |
+
total_progress += delta
|
180 |
+
if progress_listener is not None:
|
181 |
+
progress_listener.on_progress(total_progress, total_duration)
|
182 |
+
|
183 |
+
results = results_async.get()
|
184 |
+
|
185 |
+
# Call the finished callback
|
186 |
+
if progress_listener is not None:
|
187 |
+
progress_listener.on_finished()
|
188 |
+
|
189 |
+
for result in results:
|
190 |
+
# Merge the results
|
191 |
+
if (result['text'] is not None):
|
192 |
+
merged['text'] += result['text']
|
193 |
+
if (result['segments'] is not None):
|
194 |
+
merged['segments'].extend(result['segments'])
|
195 |
+
if (result['language'] is not None):
|
196 |
+
merged['language'] = result['language']
|
197 |
+
|
198 |
+
finally:
|
199 |
+
# Return the pool to the context
|
200 |
+
if (gpu_parallel_context is not None):
|
201 |
+
gpu_parallel_context.return_pool(pool)
|
202 |
+
# Always close the context if we created it
|
203 |
+
if (created_context):
|
204 |
+
gpu_parallel_context.close()
|
205 |
+
|
206 |
+
perf_end_gpu = time.perf_counter()
|
207 |
+
print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
|
208 |
+
|
209 |
+
return merged
|
210 |
+
|
211 |
+
def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
|
212 |
+
cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
|
213 |
+
parameters = []
|
214 |
+
|
215 |
+
chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
|
216 |
+
chunk_start = 0
|
217 |
+
cpu_device_id = 0
|
218 |
+
|
219 |
+
perf_start_time = time.perf_counter()
|
220 |
+
|
221 |
+
# Create chunks that will be processed on the CPU
|
222 |
+
while (chunk_start < total_duration):
|
223 |
+
chunk_end = min(chunk_start + chunk_size, total_duration)
|
224 |
+
|
225 |
+
if (chunk_end - chunk_start < 1):
|
226 |
+
# No need to process chunks that are less than 1 second
|
227 |
+
break
|
228 |
+
|
229 |
+
print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
|
230 |
+
str(chunk_end) + " on CPU device " + str(cpu_device_id))
|
231 |
+
parameters.append([audio, config, chunk_start, chunk_end]);
|
232 |
+
|
233 |
+
cpu_device_id += 1
|
234 |
+
chunk_start = chunk_end
|
235 |
+
|
236 |
+
created_context = False
|
237 |
+
|
238 |
+
# Spawn a separate process for each device
|
239 |
+
try:
|
240 |
+
if (cpu_parallel_context is None):
|
241 |
+
cpu_parallel_context = ParallelContext(cpu_device_count)
|
242 |
+
created_context = True
|
243 |
+
|
244 |
+
# Get a pool of processes
|
245 |
+
pool = cpu_parallel_context.get_pool()
|
246 |
+
|
247 |
+
# Run the transcription in parallel. Note that transcription must be picklable.
|
248 |
+
results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
|
249 |
+
|
250 |
+
timestamps = []
|
251 |
+
|
252 |
+
# Flatten the results
|
253 |
+
for result in results:
|
254 |
+
timestamps.extend(result)
|
255 |
+
|
256 |
+
merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
|
257 |
+
|
258 |
+
perf_end_time = time.perf_counter()
|
259 |
+
print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
|
260 |
+
return merged
|
261 |
+
|
262 |
+
finally:
|
263 |
+
# Return the pool to the context
|
264 |
+
if (cpu_parallel_context is not None):
|
265 |
+
cpu_parallel_context.return_pool(pool)
|
266 |
+
# Always close the context if we created it
|
267 |
+
if (created_context):
|
268 |
+
cpu_parallel_context.close()
|
269 |
+
|
270 |
+
def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
|
271 |
+
return []
|
272 |
+
|
273 |
+
def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
|
274 |
+
# Override timestamps that will be processed
|
275 |
+
if (config.override_timestamps is not None):
|
276 |
+
print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
|
277 |
+
return config.override_timestamps
|
278 |
+
return super().get_merged_timestamps(timestamps, config, total_duration)
|
279 |
+
|
280 |
+
def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: ParallelTranscriptionConfig,
|
281 |
+
progressListener: ProgressListener = None):
|
282 |
+
# Override device ID the first time
|
283 |
+
if (os.environ.get("INITIALIZED", None) is None):
|
284 |
+
os.environ["INITIALIZED"] = "1"
|
285 |
+
|
286 |
+
# Note that this may be None if the user didn't specify a device. In that case, Whisper will
|
287 |
+
# just use the default GPU device.
|
288 |
+
if (config.device_id is not None):
|
289 |
+
print("Using device " + config.device_id)
|
290 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
|
291 |
+
|
292 |
+
return super().transcribe(audio, whisperCallable, config, progressListener)
|
293 |
+
|
294 |
+
def _split(self, a, n):
|
295 |
+
"""Split a list into n approximately equal parts."""
|
296 |
+
k, m = divmod(len(a), n)
|
297 |
+
return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
|
298 |
+
|
src/whisper/abstractWhisperContainer.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from src.config import ModelConfig, VadInitialPromptMode
|
5 |
+
|
6 |
+
from src.hooks.progressListener import ProgressListener
|
7 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
8 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
9 |
+
|
10 |
+
class AbstractWhisperCallback:
|
11 |
+
def __init__(self):
|
12 |
+
self.__prompt_mode_gpt = None
|
13 |
+
|
14 |
+
@abc.abstractmethod
|
15 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
16 |
+
"""
|
17 |
+
Peform the transcription of the given audio file or data.
|
18 |
+
|
19 |
+
Parameters
|
20 |
+
----------
|
21 |
+
audio: Union[str, np.ndarray, torch.Tensor]
|
22 |
+
The audio file to transcribe, or the audio data as a numpy array or torch tensor.
|
23 |
+
segment_index: int
|
24 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
25 |
+
task: str
|
26 |
+
The task - either translate or transcribe.
|
27 |
+
progress_listener: ProgressListener
|
28 |
+
A callback to receive progress updates.
|
29 |
+
"""
|
30 |
+
raise NotImplementedError()
|
31 |
+
|
32 |
+
class AbstractWhisperContainer:
|
33 |
+
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
34 |
+
download_root: str = None,
|
35 |
+
cache: ModelCache = None, models: List[ModelConfig] = []):
|
36 |
+
self.model_name = model_name
|
37 |
+
self.device = device
|
38 |
+
self.compute_type = compute_type
|
39 |
+
self.download_root = download_root
|
40 |
+
self.cache = cache
|
41 |
+
|
42 |
+
# Will be created on demand
|
43 |
+
self.model = None
|
44 |
+
|
45 |
+
# List of known models
|
46 |
+
self.models = models
|
47 |
+
|
48 |
+
def get_model(self):
|
49 |
+
if self.model is None:
|
50 |
+
|
51 |
+
if (self.cache is None):
|
52 |
+
self.model = self._create_model()
|
53 |
+
else:
|
54 |
+
model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
|
55 |
+
self.model = self.cache.get(model_key, self._create_model)
|
56 |
+
return self.model
|
57 |
+
|
58 |
+
@abc.abstractmethod
|
59 |
+
def _create_model(self):
|
60 |
+
raise NotImplementedError()
|
61 |
+
|
62 |
+
def ensure_downloaded(self):
|
63 |
+
pass
|
64 |
+
|
65 |
+
@abc.abstractmethod
|
66 |
+
def create_callback(self, language: str = None, task: str = None,
|
67 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
68 |
+
**decodeOptions: dict) -> AbstractWhisperCallback:
|
69 |
+
"""
|
70 |
+
Create a WhisperCallback object that can be used to transcript audio files.
|
71 |
+
|
72 |
+
Parameters
|
73 |
+
----------
|
74 |
+
language: str
|
75 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
76 |
+
task: str
|
77 |
+
The task - either translate or transcribe.
|
78 |
+
prompt_strategy: AbstractPromptStrategy
|
79 |
+
The prompt strategy to use for the transcription.
|
80 |
+
decodeOptions: dict
|
81 |
+
Additional options to pass to the decoder. Must be pickleable.
|
82 |
+
|
83 |
+
Returns
|
84 |
+
-------
|
85 |
+
A WhisperCallback object.
|
86 |
+
"""
|
87 |
+
raise NotImplementedError()
|
88 |
+
|
89 |
+
# This is required for multiprocessing
|
90 |
+
def __getstate__(self):
|
91 |
+
return {
|
92 |
+
"model_name": self.model_name,
|
93 |
+
"device": self.device,
|
94 |
+
"download_root": self.download_root,
|
95 |
+
"models": self.models,
|
96 |
+
"compute_type": self.compute_type
|
97 |
+
}
|
98 |
+
|
99 |
+
def __setstate__(self, state):
|
100 |
+
self.model_name = state["model_name"]
|
101 |
+
self.device = state["device"]
|
102 |
+
self.download_root = state["download_root"]
|
103 |
+
self.models = state["models"]
|
104 |
+
self.compute_type = state["compute_type"]
|
105 |
+
self.model = None
|
106 |
+
# Depickled objects must use the global cache
|
107 |
+
self.cache = GLOBAL_MODEL_CACHE
|
src/whisper/fasterWhisperContainer.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import List, Union
|
3 |
+
|
4 |
+
from faster_whisper import WhisperModel, download_model
|
5 |
+
from src.config import ModelConfig, VadInitialPromptMode
|
6 |
+
from src.hooks.progressListener import ProgressListener
|
7 |
+
from src.languages import get_language_from_name
|
8 |
+
from src.modelCache import ModelCache
|
9 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
10 |
+
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
11 |
+
from src.utils import format_timestamp
|
12 |
+
|
13 |
+
class FasterWhisperContainer(AbstractWhisperContainer):
|
14 |
+
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
15 |
+
download_root: str = None,
|
16 |
+
cache: ModelCache = None, models: List[ModelConfig] = []):
|
17 |
+
super().__init__(model_name, device, compute_type, download_root, cache, models)
|
18 |
+
|
19 |
+
def ensure_downloaded(self):
|
20 |
+
"""
|
21 |
+
Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
|
22 |
+
passing the container to a subprocess.
|
23 |
+
"""
|
24 |
+
model_config = self._get_model_config()
|
25 |
+
|
26 |
+
if os.path.isdir(model_config.url):
|
27 |
+
model_config.path = model_config.url
|
28 |
+
else:
|
29 |
+
model_config.path = download_model(model_config.url, output_dir=self.download_root)
|
30 |
+
|
31 |
+
def _get_model_config(self) -> ModelConfig:
|
32 |
+
"""
|
33 |
+
Get the model configuration for the model.
|
34 |
+
"""
|
35 |
+
for model in self.models:
|
36 |
+
if model.name == self.model_name:
|
37 |
+
return model
|
38 |
+
return None
|
39 |
+
|
40 |
+
def _create_model(self):
|
41 |
+
print("Loading faster whisper model " + self.model_name + " for device " + str(self.device))
|
42 |
+
model_config = self._get_model_config()
|
43 |
+
model_url = model_config.url
|
44 |
+
|
45 |
+
if model_config.type == "whisper":
|
46 |
+
if model_url not in ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]:
|
47 |
+
raise Exception("FasterWhisperContainer does not yet support Whisper models. Use ct2-transformers-converter to convert the model to a faster-whisper model.")
|
48 |
+
if model_url == "large":
|
49 |
+
# large is an alias for large-v1
|
50 |
+
model_url = "large-v1"
|
51 |
+
|
52 |
+
device = self.device
|
53 |
+
|
54 |
+
if (device is None):
|
55 |
+
device = "auto"
|
56 |
+
|
57 |
+
model = WhisperModel(model_url, device=device, compute_type=self.compute_type)
|
58 |
+
return model
|
59 |
+
|
60 |
+
def create_callback(self, language: str = None, task: str = None,
|
61 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
62 |
+
**decodeOptions: dict) -> AbstractWhisperCallback:
|
63 |
+
"""
|
64 |
+
Create a WhisperCallback object that can be used to transcript audio files.
|
65 |
+
|
66 |
+
Parameters
|
67 |
+
----------
|
68 |
+
language: str
|
69 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
70 |
+
task: str
|
71 |
+
The task - either translate or transcribe.
|
72 |
+
prompt_strategy: AbstractPromptStrategy
|
73 |
+
The prompt strategy to use. If not specified, the prompt from Whisper will be used.
|
74 |
+
decodeOptions: dict
|
75 |
+
Additional options to pass to the decoder. Must be pickleable.
|
76 |
+
|
77 |
+
Returns
|
78 |
+
-------
|
79 |
+
A WhisperCallback object.
|
80 |
+
"""
|
81 |
+
return FasterWhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
|
82 |
+
|
83 |
+
class FasterWhisperCallback(AbstractWhisperCallback):
|
84 |
+
def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
|
85 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
86 |
+
**decodeOptions: dict):
|
87 |
+
self.model_container = model_container
|
88 |
+
self.language = language
|
89 |
+
self.task = task
|
90 |
+
self.prompt_strategy = prompt_strategy
|
91 |
+
self.decodeOptions = decodeOptions
|
92 |
+
|
93 |
+
self._printed_warning = False
|
94 |
+
|
95 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
96 |
+
"""
|
97 |
+
Peform the transcription of the given audio file or data.
|
98 |
+
|
99 |
+
Parameters
|
100 |
+
----------
|
101 |
+
audio: Union[str, np.ndarray, torch.Tensor]
|
102 |
+
The audio file to transcribe, or the audio data as a numpy array or torch tensor.
|
103 |
+
segment_index: int
|
104 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
105 |
+
task: str
|
106 |
+
The task - either translate or transcribe.
|
107 |
+
progress_listener: ProgressListener
|
108 |
+
A callback to receive progress updates.
|
109 |
+
"""
|
110 |
+
model: WhisperModel = self.model_container.get_model()
|
111 |
+
language_code = self._lookup_language_code(self.language) if self.language else None
|
112 |
+
|
113 |
+
# Copy decode options and remove options that are not supported by faster-whisper
|
114 |
+
decodeOptions = self.decodeOptions.copy()
|
115 |
+
verbose = decodeOptions.pop("verbose", None)
|
116 |
+
|
117 |
+
logprob_threshold = decodeOptions.pop("logprob_threshold", None)
|
118 |
+
|
119 |
+
patience = decodeOptions.pop("patience", None)
|
120 |
+
length_penalty = decodeOptions.pop("length_penalty", None)
|
121 |
+
suppress_tokens = decodeOptions.pop("suppress_tokens", None)
|
122 |
+
|
123 |
+
if (decodeOptions.pop("fp16", None) is not None):
|
124 |
+
if not self._printed_warning:
|
125 |
+
print("WARNING: fp16 option is ignored by faster-whisper - use compute_type instead.")
|
126 |
+
self._printed_warning = True
|
127 |
+
|
128 |
+
# Fix up decode options
|
129 |
+
if (logprob_threshold is not None):
|
130 |
+
decodeOptions["log_prob_threshold"] = logprob_threshold
|
131 |
+
|
132 |
+
decodeOptions["patience"] = float(patience) if patience is not None else 1.0
|
133 |
+
decodeOptions["length_penalty"] = float(length_penalty) if length_penalty is not None else 1.0
|
134 |
+
|
135 |
+
# See if supress_tokens is a string - if so, convert it to a list of ints
|
136 |
+
decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
|
137 |
+
|
138 |
+
initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
|
139 |
+
if self.prompt_strategy else prompt
|
140 |
+
|
141 |
+
segments_generator, info = model.transcribe(audio, \
|
142 |
+
language=language_code if language_code else detected_language, task=self.task, \
|
143 |
+
initial_prompt=initial_prompt, \
|
144 |
+
**decodeOptions
|
145 |
+
)
|
146 |
+
|
147 |
+
segments = []
|
148 |
+
|
149 |
+
for segment in segments_generator:
|
150 |
+
segments.append(segment)
|
151 |
+
|
152 |
+
if progress_listener is not None:
|
153 |
+
progress_listener.on_progress(segment.end, info.duration)
|
154 |
+
if verbose:
|
155 |
+
print("[{}->{}] {}".format(format_timestamp(segment.start, True), format_timestamp(segment.end, True),
|
156 |
+
segment.text))
|
157 |
+
|
158 |
+
text = " ".join([segment.text for segment in segments])
|
159 |
+
|
160 |
+
# Convert the segments to a format that is easier to serialize
|
161 |
+
whisper_segments = [{
|
162 |
+
"text": segment.text,
|
163 |
+
"start": segment.start,
|
164 |
+
"end": segment.end,
|
165 |
+
|
166 |
+
# Extra fields added by faster-whisper
|
167 |
+
"words": [{
|
168 |
+
"start": word.start,
|
169 |
+
"end": word.end,
|
170 |
+
"word": word.word,
|
171 |
+
"probability": word.probability
|
172 |
+
} for word in (segment.words if segment.words is not None else []) ]
|
173 |
+
} for segment in segments]
|
174 |
+
|
175 |
+
result = {
|
176 |
+
"segments": whisper_segments,
|
177 |
+
"text": text,
|
178 |
+
"language": info.language if info else None,
|
179 |
+
|
180 |
+
# Extra fields added by faster-whisper
|
181 |
+
"language_probability": info.language_probability if info else None,
|
182 |
+
"duration": info.duration if info else None
|
183 |
+
}
|
184 |
+
|
185 |
+
# If we have a prompt strategy, we need to increment the current prompt
|
186 |
+
if self.prompt_strategy:
|
187 |
+
self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
|
188 |
+
|
189 |
+
if progress_listener is not None:
|
190 |
+
progress_listener.on_finished()
|
191 |
+
return result
|
192 |
+
|
193 |
+
def _split_suppress_tokens(self, suppress_tokens: Union[str, List[int]]):
|
194 |
+
if (suppress_tokens is None):
|
195 |
+
return None
|
196 |
+
if (isinstance(suppress_tokens, list)):
|
197 |
+
return suppress_tokens
|
198 |
+
|
199 |
+
return [int(token) for token in suppress_tokens.split(",")]
|
200 |
+
|
201 |
+
def _lookup_language_code(self, language: str):
|
202 |
+
language = get_language_from_name(language)
|
203 |
+
|
204 |
+
if language is None:
|
205 |
+
raise ValueError("Invalid language: " + language)
|
206 |
+
|
207 |
+
return language.code
|
src/whisper/whisperContainer.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# External programs
|
2 |
+
import abc
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from typing import List
|
6 |
+
from urllib.parse import urlparse
|
7 |
+
import torch
|
8 |
+
import urllib3
|
9 |
+
from src.hooks.progressListener import ProgressListener
|
10 |
+
|
11 |
+
import whisper
|
12 |
+
from whisper import Whisper
|
13 |
+
|
14 |
+
from src.config import ModelConfig, VadInitialPromptMode
|
15 |
+
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
16 |
+
|
17 |
+
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
18 |
+
from src.prompts.abstractPromptStrategy import AbstractPromptStrategy
|
19 |
+
from src.utils import download_file
|
20 |
+
from src.whisper.abstractWhisperContainer import AbstractWhisperCallback, AbstractWhisperContainer
|
21 |
+
|
22 |
+
class WhisperContainer(AbstractWhisperContainer):
|
23 |
+
def __init__(self, model_name: str, device: str = None, compute_type: str = "float16",
|
24 |
+
download_root: str = None,
|
25 |
+
cache: ModelCache = None, models: List[ModelConfig] = []):
|
26 |
+
if device is None:
|
27 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
+
super().__init__(model_name, device, compute_type, download_root, cache, models)
|
29 |
+
|
30 |
+
def ensure_downloaded(self):
|
31 |
+
"""
|
32 |
+
Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
|
33 |
+
passing the container to a subprocess.
|
34 |
+
"""
|
35 |
+
# Warning: Using private API here
|
36 |
+
try:
|
37 |
+
root_dir = self.download_root
|
38 |
+
model_config = self._get_model_config()
|
39 |
+
|
40 |
+
if root_dir is None:
|
41 |
+
root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
42 |
+
|
43 |
+
if self.model_name in whisper._MODELS:
|
44 |
+
whisper._download(whisper._MODELS[self.model_name], root_dir, False)
|
45 |
+
else:
|
46 |
+
# If the model is not in the official list, see if it needs to be downloaded
|
47 |
+
model_config.download_url(root_dir)
|
48 |
+
return True
|
49 |
+
|
50 |
+
except Exception as e:
|
51 |
+
# Given that the API is private, it could change at any time. We don't want to crash the program
|
52 |
+
print("Error pre-downloading model: " + str(e))
|
53 |
+
return False
|
54 |
+
|
55 |
+
def _get_model_config(self) -> ModelConfig:
|
56 |
+
"""
|
57 |
+
Get the model configuration for the model.
|
58 |
+
"""
|
59 |
+
for model in self.models:
|
60 |
+
if model.name == self.model_name:
|
61 |
+
return model
|
62 |
+
return None
|
63 |
+
|
64 |
+
def _create_model(self):
|
65 |
+
print("Loading whisper model " + self.model_name)
|
66 |
+
model_config = self._get_model_config()
|
67 |
+
|
68 |
+
# Note that the model will not be downloaded in the case of an official Whisper model
|
69 |
+
model_path = self._get_model_path(model_config, self.download_root)
|
70 |
+
|
71 |
+
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
72 |
+
|
73 |
+
def create_callback(self, language: str = None, task: str = None,
|
74 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
75 |
+
**decodeOptions: dict) -> AbstractWhisperCallback:
|
76 |
+
"""
|
77 |
+
Create a WhisperCallback object that can be used to transcript audio files.
|
78 |
+
|
79 |
+
Parameters
|
80 |
+
----------
|
81 |
+
language: str
|
82 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
83 |
+
task: str
|
84 |
+
The task - either translate or transcribe.
|
85 |
+
prompt_strategy: AbstractPromptStrategy
|
86 |
+
The prompt strategy to use. If not specified, the prompt from Whisper will be used.
|
87 |
+
decodeOptions: dict
|
88 |
+
Additional options to pass to the decoder. Must be pickleable.
|
89 |
+
|
90 |
+
Returns
|
91 |
+
-------
|
92 |
+
A WhisperCallback object.
|
93 |
+
"""
|
94 |
+
return WhisperCallback(self, language=language, task=task, prompt_strategy=prompt_strategy, **decodeOptions)
|
95 |
+
|
96 |
+
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
97 |
+
from src.conversion.hf_converter import convert_hf_whisper
|
98 |
+
"""
|
99 |
+
Download the model.
|
100 |
+
|
101 |
+
Parameters
|
102 |
+
----------
|
103 |
+
model_config: ModelConfig
|
104 |
+
The model configuration.
|
105 |
+
"""
|
106 |
+
# See if path is already set
|
107 |
+
if model_config.path is not None:
|
108 |
+
return model_config.path
|
109 |
+
|
110 |
+
if root_dir is None:
|
111 |
+
root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
112 |
+
|
113 |
+
model_type = model_config.type.lower() if model_config.type is not None else "whisper"
|
114 |
+
|
115 |
+
if model_type in ["huggingface", "hf"]:
|
116 |
+
model_config.path = model_config.url
|
117 |
+
destination_target = os.path.join(root_dir, model_config.name + ".pt")
|
118 |
+
|
119 |
+
# Convert from HuggingFace format to Whisper format
|
120 |
+
if os.path.exists(destination_target):
|
121 |
+
print(f"File {destination_target} already exists, skipping conversion")
|
122 |
+
else:
|
123 |
+
print("Saving HuggingFace model in Whisper format to " + destination_target)
|
124 |
+
convert_hf_whisper(model_config.url, destination_target)
|
125 |
+
|
126 |
+
model_config.path = destination_target
|
127 |
+
|
128 |
+
elif model_type in ["whisper", "w"]:
|
129 |
+
model_config.path = model_config.url
|
130 |
+
|
131 |
+
# See if URL is just a file
|
132 |
+
if model_config.url in whisper._MODELS:
|
133 |
+
# No need to download anything - Whisper will handle it
|
134 |
+
model_config.path = model_config.url
|
135 |
+
elif model_config.url.startswith("file://"):
|
136 |
+
# Get file path
|
137 |
+
model_config.path = urlparse(model_config.url).path
|
138 |
+
# See if it is an URL
|
139 |
+
elif model_config.url.startswith("http://") or model_config.url.startswith("https://"):
|
140 |
+
# Extension (or file name)
|
141 |
+
extension = os.path.splitext(model_config.url)[-1]
|
142 |
+
download_target = os.path.join(root_dir, model_config.name + extension)
|
143 |
+
|
144 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
145 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
146 |
+
|
147 |
+
if not os.path.isfile(download_target):
|
148 |
+
download_file(model_config.url, download_target)
|
149 |
+
else:
|
150 |
+
print(f"File {download_target} already exists, skipping download")
|
151 |
+
|
152 |
+
model_config.path = download_target
|
153 |
+
# Must be a local file
|
154 |
+
else:
|
155 |
+
model_config.path = model_config.url
|
156 |
+
|
157 |
+
else:
|
158 |
+
raise ValueError(f"Unknown model type {model_type}")
|
159 |
+
|
160 |
+
return model_config.path
|
161 |
+
|
162 |
+
class WhisperCallback(AbstractWhisperCallback):
|
163 |
+
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None,
|
164 |
+
prompt_strategy: AbstractPromptStrategy = None,
|
165 |
+
**decodeOptions: dict):
|
166 |
+
self.model_container = model_container
|
167 |
+
self.language = language
|
168 |
+
self.task = task
|
169 |
+
self.prompt_strategy = prompt_strategy
|
170 |
+
|
171 |
+
self.decodeOptions = decodeOptions
|
172 |
+
|
173 |
+
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
174 |
+
"""
|
175 |
+
Peform the transcription of the given audio file or data.
|
176 |
+
|
177 |
+
Parameters
|
178 |
+
----------
|
179 |
+
audio: Union[str, np.ndarray, torch.Tensor]
|
180 |
+
The audio file to transcribe, or the audio data as a numpy array or torch tensor.
|
181 |
+
segment_index: int
|
182 |
+
The target language of the transcription. If not specified, the language will be inferred from the audio content.
|
183 |
+
task: str
|
184 |
+
The task - either translate or transcribe.
|
185 |
+
progress_listener: ProgressListener
|
186 |
+
A callback to receive progress updates.
|
187 |
+
"""
|
188 |
+
model = self.model_container.get_model()
|
189 |
+
|
190 |
+
if progress_listener is not None:
|
191 |
+
with create_progress_listener_handle(progress_listener):
|
192 |
+
return self._transcribe(model, audio, segment_index, prompt, detected_language)
|
193 |
+
else:
|
194 |
+
return self._transcribe(model, audio, segment_index, prompt, detected_language)
|
195 |
+
|
196 |
+
def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
|
197 |
+
decodeOptions = self.decodeOptions.copy()
|
198 |
+
|
199 |
+
# Add fp16
|
200 |
+
if self.model_container.compute_type in ["fp16", "float16"]:
|
201 |
+
decodeOptions["fp16"] = True
|
202 |
+
|
203 |
+
initial_prompt = self.prompt_strategy.get_segment_prompt(segment_index, prompt, detected_language) \
|
204 |
+
if self.prompt_strategy else prompt
|
205 |
+
|
206 |
+
result = model.transcribe(audio, \
|
207 |
+
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
+
initial_prompt=initial_prompt, \
|
209 |
+
**decodeOptions
|
210 |
+
)
|
211 |
+
|
212 |
+
# If we have a prompt strategy, we need to increment the current prompt
|
213 |
+
if self.prompt_strategy:
|
214 |
+
self.prompt_strategy.on_segment_finished(segment_index, prompt, detected_language, result)
|
215 |
+
|
216 |
+
return result
|
src/whisper/whisperFactory.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from src import modelCache
|
3 |
+
from src.config import ModelConfig
|
4 |
+
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
5 |
+
|
6 |
+
def create_whisper_container(whisper_implementation: str,
|
7 |
+
model_name: str, device: str = None, compute_type: str = "float16",
|
8 |
+
download_root: str = None,
|
9 |
+
cache: modelCache = None, models: List[ModelConfig] = []) -> AbstractWhisperContainer:
|
10 |
+
print("Creating whisper container for " + whisper_implementation)
|
11 |
+
|
12 |
+
if (whisper_implementation == "whisper"):
|
13 |
+
from src.whisper.whisperContainer import WhisperContainer
|
14 |
+
return WhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
|
15 |
+
elif (whisper_implementation == "faster-whisper" or whisper_implementation == "faster_whisper"):
|
16 |
+
from src.whisper.fasterWhisperContainer import FasterWhisperContainer
|
17 |
+
return FasterWhisperContainer(model_name=model_name, device=device, compute_type=compute_type, download_root=download_root, cache=cache, models=models)
|
18 |
+
else:
|
19 |
+
raise ValueError("Unknown Whisper implementation: " + whisper_implementation)
|
tests/segments_test.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import unittest
|
3 |
+
|
4 |
+
sys.path.append('../whisper-webui')
|
5 |
+
|
6 |
+
from src.segments import merge_timestamps
|
7 |
+
|
8 |
+
class TestSegments(unittest.TestCase):
|
9 |
+
def __init__(self, *args, **kwargs):
|
10 |
+
super(TestSegments, self).__init__(*args, **kwargs)
|
11 |
+
|
12 |
+
def test_merge_segments(self):
|
13 |
+
segments = [
|
14 |
+
{'start': 10.0, 'end': 20.0},
|
15 |
+
{'start': 22.0, 'end': 27.0},
|
16 |
+
{'start': 31.0, 'end': 35.0},
|
17 |
+
{'start': 45.0, 'end': 60.0},
|
18 |
+
{'start': 61.0, 'end': 65.0},
|
19 |
+
{'start': 68.0, 'end': 98.0},
|
20 |
+
{'start': 100.0, 'end': 102.0},
|
21 |
+
{'start': 110.0, 'end': 112.0}
|
22 |
+
]
|
23 |
+
|
24 |
+
result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
|
25 |
+
|
26 |
+
self.assertListEqual(result, [
|
27 |
+
{'start': 9.0, 'end': 36.0},
|
28 |
+
{'start': 44.0, 'end': 66.0},
|
29 |
+
{'start': 67.0, 'end': 99.0},
|
30 |
+
{'start': 99.0, 'end': 103.0},
|
31 |
+
{'start': 109.0, 'end': 113.0}
|
32 |
+
])
|
33 |
+
|
34 |
+
def test_overlap_next(self):
|
35 |
+
segments = [
|
36 |
+
{'start': 5.0, 'end': 39.182},
|
37 |
+
{'start': 39.986, 'end': 40.814}
|
38 |
+
]
|
39 |
+
|
40 |
+
result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
|
41 |
+
|
42 |
+
self.assertListEqual(result, [
|
43 |
+
{'start': 4.0, 'end': 39.584},
|
44 |
+
{'start': 39.584, 'end': 41.814}
|
45 |
+
])
|
46 |
+
|
47 |
+
if __name__ == '__main__':
|
48 |
+
unittest.main()
|
tests/vad_test.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pprint
|
2 |
+
import unittest
|
3 |
+
import numpy as np
|
4 |
+
import sys
|
5 |
+
|
6 |
+
sys.path.append('../whisper-webui')
|
7 |
+
|
8 |
+
from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
|
9 |
+
|
10 |
+
class TestVad(unittest.TestCase):
|
11 |
+
def __init__(self, *args, **kwargs):
|
12 |
+
super(TestVad, self).__init__(*args, **kwargs)
|
13 |
+
self.transcribe_calls = []
|
14 |
+
|
15 |
+
def test_transcript(self):
|
16 |
+
mock = MockVadTranscription()
|
17 |
+
|
18 |
+
self.transcribe_calls.clear()
|
19 |
+
result = mock.transcribe("mock", lambda segment : self.transcribe_segments(segment))
|
20 |
+
|
21 |
+
self.assertListEqual(self.transcribe_calls, [
|
22 |
+
[30, 30],
|
23 |
+
[100, 100]
|
24 |
+
])
|
25 |
+
|
26 |
+
self.assertListEqual(result['segments'],
|
27 |
+
[{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
|
28 |
+
{'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
|
29 |
+
)
|
30 |
+
|
31 |
+
def transcribe_segments(self, segment):
|
32 |
+
self.transcribe_calls.append(segment.tolist())
|
33 |
+
|
34 |
+
# Dummy text
|
35 |
+
return {
|
36 |
+
'text': "Hello world ",
|
37 |
+
'segments': [
|
38 |
+
{
|
39 |
+
"start": 10.0,
|
40 |
+
"end": 20.0,
|
41 |
+
"text": "Hello world "
|
42 |
+
}
|
43 |
+
],
|
44 |
+
'language': ""
|
45 |
+
}
|
46 |
+
|
47 |
+
class MockVadTranscription(AbstractTranscription):
|
48 |
+
def __init__(self):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
52 |
+
start_time_seconds = float(start_time.removesuffix("s"))
|
53 |
+
duration_seconds = float(duration.removesuffix("s"))
|
54 |
+
|
55 |
+
# For mocking, this just returns a simple numppy array
|
56 |
+
return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
|
57 |
+
|
58 |
+
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, duration: float):
|
59 |
+
result = []
|
60 |
+
|
61 |
+
result.append( { 'start': 30, 'end': 60 } )
|
62 |
+
result.append( { 'start': 100, 'end': 200 } )
|
63 |
+
return result
|
64 |
+
|
65 |
+
if __name__ == '__main__':
|
66 |
+
unittest.main()
|