Spaces:
Runtime error
Runtime error
Daryl Fung
commited on
Commit
•
9a3c8f5
1
Parent(s):
932db78
added generation audio
Browse files- db/audio_db/is3/is3.py +6 -2
- db/audio_db/is3/wrapper.py +3 -12
- db/generate_audio.py +22 -12
db/audio_db/is3/is3.py
CHANGED
@@ -23,6 +23,7 @@ class UploadedObject(BaseModel):
|
|
23 |
obj_id: ObjectId
|
24 |
deletehash: str
|
25 |
cached_obj: Any = None
|
|
|
26 |
|
27 |
def __getstate__(self):
|
28 |
d = super().__getstate__()
|
@@ -63,13 +64,14 @@ class StagedObject(BaseModel):
|
|
63 |
during the same runtime does not need to download the object.
|
64 |
"""
|
65 |
async with Imgur() as imgur:
|
66 |
-
oid, delete = await imgur.upload_image(self.image())
|
67 |
|
68 |
return UploadedObject(
|
69 |
name=self.name,
|
70 |
obj_id=oid,
|
71 |
deletehash=delete,
|
72 |
-
cached_obj=self.obj
|
|
|
73 |
)
|
74 |
|
75 |
|
@@ -122,6 +124,8 @@ class Bucket:
|
|
122 |
'\n'.join(o.name for o in self.pending.values())
|
123 |
)
|
124 |
raise Warning(msg)
|
|
|
|
|
125 |
|
126 |
async def get_obj(self, name: str) -> Any:
|
127 |
return await self.uploaded[name].download()
|
|
|
23 |
obj_id: ObjectId
|
24 |
deletehash: str
|
25 |
cached_obj: Any = None
|
26 |
+
link: str
|
27 |
|
28 |
def __getstate__(self):
|
29 |
d = super().__getstate__()
|
|
|
64 |
during the same runtime does not need to download the object.
|
65 |
"""
|
66 |
async with Imgur() as imgur:
|
67 |
+
oid, delete, link = await imgur.upload_image(self.image())
|
68 |
|
69 |
return UploadedObject(
|
70 |
name=self.name,
|
71 |
obj_id=oid,
|
72 |
deletehash=delete,
|
73 |
+
cached_obj=self.obj,
|
74 |
+
link=link
|
75 |
)
|
76 |
|
77 |
|
|
|
124 |
'\n'.join(o.name for o in self.pending.values())
|
125 |
)
|
126 |
raise Warning(msg)
|
127 |
+
|
128 |
+
return uploaded
|
129 |
|
130 |
async def get_obj(self, name: str) -> Any:
|
131 |
return await self.uploaded[name].download()
|
db/audio_db/is3/wrapper.py
CHANGED
@@ -11,7 +11,6 @@ from typing import Optional, Union, Tuple
|
|
11 |
from .utils import image_to_b64_string, bytes_to_image
|
12 |
|
13 |
dotenv.load_dotenv()
|
14 |
-
AUTH_HEADER = {'Authorization': f"Client-ID {os.getenv('IS3_CLIENT_ID')}"}
|
15 |
API_ENDPOINTS = {
|
16 |
'upload': 'https://api.imgur.com/3/upload/',
|
17 |
'download': 'http://i.imgur.com/',
|
@@ -26,7 +25,6 @@ async def get_tokens():
|
|
26 |
r = await session.request(
|
27 |
method='post',
|
28 |
url=API_ENDPOINTS['auth'],
|
29 |
-
headers=AUTH_HEADER,
|
30 |
data={
|
31 |
'refresh_token': os.getenv("IS3_REFRESH_TOKEN"),
|
32 |
'client_id': os.getenv("IS3_CLIENT_ID"),
|
@@ -38,6 +36,8 @@ async def get_tokens():
|
|
38 |
return r['access_token'], r['refresh_token']
|
39 |
|
40 |
ACCESS_TOKEN, REFRESH_TOKEN = asyncio.run(get_tokens())
|
|
|
|
|
41 |
|
42 |
|
43 |
|
@@ -73,7 +73,7 @@ class ImgurClient:
|
|
73 |
headers=AUTH_HEADER,
|
74 |
data={'image': data, 'type': 'base64'}
|
75 |
)
|
76 |
-
return r['id'], r['deletehash']
|
77 |
|
78 |
async def download_image(self, image_id: str) -> Image.Image:
|
79 |
"""Download the image and return the data as bytes."""
|
@@ -86,12 +86,3 @@ class ImgurClient:
|
|
86 |
"""Delete an image using a deletehash string"""
|
87 |
url = API_ENDPOINTS['delete'] + deletehash
|
88 |
await self._request('delete', url, headers=AUTH_HEADER)
|
89 |
-
|
90 |
-
|
91 |
-
async def get_token():
|
92 |
-
im = ImgurClient()
|
93 |
-
await im.get_access_token()
|
94 |
-
|
95 |
-
import asyncio
|
96 |
-
loop = asyncio.get_event_loop()
|
97 |
-
loop.run_until_complete(get_token())
|
|
|
11 |
from .utils import image_to_b64_string, bytes_to_image
|
12 |
|
13 |
dotenv.load_dotenv()
|
|
|
14 |
API_ENDPOINTS = {
|
15 |
'upload': 'https://api.imgur.com/3/upload/',
|
16 |
'download': 'http://i.imgur.com/',
|
|
|
25 |
r = await session.request(
|
26 |
method='post',
|
27 |
url=API_ENDPOINTS['auth'],
|
|
|
28 |
data={
|
29 |
'refresh_token': os.getenv("IS3_REFRESH_TOKEN"),
|
30 |
'client_id': os.getenv("IS3_CLIENT_ID"),
|
|
|
36 |
return r['access_token'], r['refresh_token']
|
37 |
|
38 |
ACCESS_TOKEN, REFRESH_TOKEN = asyncio.run(get_tokens())
|
39 |
+
# AUTH_HEADER = {'Authorization': f"Client-ID {os.getenv('IS3_CLIENT_ID')}"}
|
40 |
+
AUTH_HEADER = {'Authorization': f"Bearer {ACCESS_TOKEN}"}
|
41 |
|
42 |
|
43 |
|
|
|
73 |
headers=AUTH_HEADER,
|
74 |
data={'image': data, 'type': 'base64'}
|
75 |
)
|
76 |
+
return r['id'], r['deletehash'], r['link']
|
77 |
|
78 |
async def download_image(self, image_id: str) -> Image.Image:
|
79 |
"""Download the image and return the data as bytes."""
|
|
|
86 |
"""Delete an image using a deletehash string"""
|
87 |
url = API_ENDPOINTS['delete'] + deletehash
|
88 |
await self._request('delete', url, headers=AUTH_HEADER)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
db/generate_audio.py
CHANGED
@@ -1,13 +1,17 @@
|
|
1 |
-
|
|
|
2 |
from pymilvus import Collection
|
3 |
import asyncio
|
4 |
-
|
|
|
5 |
from db_connect import connect
|
6 |
|
7 |
# for audio storage
|
8 |
-
from audio_db.is3 import
|
9 |
|
10 |
connect()
|
|
|
|
|
11 |
|
12 |
async def generate_audio():
|
13 |
response = Collection("Response")
|
@@ -15,7 +19,8 @@ async def generate_audio():
|
|
15 |
|
16 |
data = []
|
17 |
|
18 |
-
response_iterator = response.query_iterator(
|
|
|
19 |
while True:
|
20 |
res = response_iterator.next()
|
21 |
if len(res) == 0:
|
@@ -24,22 +29,27 @@ async def generate_audio():
|
|
24 |
response_iterator.close()
|
25 |
break
|
26 |
|
27 |
-
bucket = is3.Bucket(str(res[0]['id']))
|
28 |
-
|
29 |
# generate audio
|
30 |
-
|
|
|
|
|
|
|
31 |
|
32 |
# store the audio
|
33 |
-
|
34 |
-
await
|
35 |
|
36 |
# save the audio record to AudioResponse
|
37 |
-
data.append([res[0]['text'],
|
|
|
38 |
|
39 |
audio_response.insert(list(zip(*data)))
|
40 |
audio_response.flush()
|
41 |
|
|
|
|
|
|
|
42 |
|
43 |
if __name__ == '__main__':
|
44 |
-
loop = asyncio.
|
45 |
-
loop.run_until_complete(generate_audio())
|
|
|
1 |
+
import io
|
2 |
+
|
3 |
from pymilvus import Collection
|
4 |
import asyncio
|
5 |
+
from bark import SAMPLE_RATE, generate_audio, preload_models
|
6 |
+
from scipy.io.wavfile import write
|
7 |
from db_connect import connect
|
8 |
|
9 |
# for audio storage
|
10 |
+
from audio_db.is3.is3 import StagedObject
|
11 |
|
12 |
connect()
|
13 |
+
preload_models()
|
14 |
+
|
15 |
|
16 |
async def generate_audio():
|
17 |
response = Collection("Response")
|
|
|
19 |
|
20 |
data = []
|
21 |
|
22 |
+
response_iterator = response.query_iterator(batch_size=1, output_fields=['text', 'embeddings'])
|
23 |
+
ids_to_delete = []
|
24 |
while True:
|
25 |
res = response_iterator.next()
|
26 |
if len(res) == 0:
|
|
|
29 |
response_iterator.close()
|
30 |
break
|
31 |
|
|
|
|
|
32 |
# generate audio
|
33 |
+
audio_array = generate_audio(res[0]['text'], history_prompt="en_speaker_3")
|
34 |
+
bytes_io = io.BytesIO()
|
35 |
+
write(bytes_io, SAMPLE_RATE, audio_array)
|
36 |
+
audio_bytes = bytes_io.read()
|
37 |
|
38 |
# store the audio
|
39 |
+
obj = StagedObject(obj=audio_bytes, name='audio')
|
40 |
+
uploaded_object = await obj.upload()
|
41 |
|
42 |
# save the audio record to AudioResponse
|
43 |
+
data.append([res[0]['text'], uploaded_object.obj_id, res[0]['embeddings']])
|
44 |
+
ids_to_delete.append(res[0]['id'])
|
45 |
|
46 |
audio_response.insert(list(zip(*data)))
|
47 |
audio_response.flush()
|
48 |
|
49 |
+
# delete text to generate audio
|
50 |
+
response.delete(expr=f"id in {str(ids_to_delete)}")
|
51 |
+
|
52 |
|
53 |
if __name__ == '__main__':
|
54 |
+
loop = asyncio.new_event_loop()
|
55 |
+
loop.run_until_complete(generate_audio())
|