Daryl Fung commited on
Commit
9a3c8f5
1 Parent(s): 932db78

added generation audio

Browse files
db/audio_db/is3/is3.py CHANGED
@@ -23,6 +23,7 @@ class UploadedObject(BaseModel):
23
  obj_id: ObjectId
24
  deletehash: str
25
  cached_obj: Any = None
 
26
 
27
  def __getstate__(self):
28
  d = super().__getstate__()
@@ -63,13 +64,14 @@ class StagedObject(BaseModel):
63
  during the same runtime does not need to download the object.
64
  """
65
  async with Imgur() as imgur:
66
- oid, delete = await imgur.upload_image(self.image())
67
 
68
  return UploadedObject(
69
  name=self.name,
70
  obj_id=oid,
71
  deletehash=delete,
72
- cached_obj=self.obj
 
73
  )
74
 
75
 
@@ -122,6 +124,8 @@ class Bucket:
122
  '\n'.join(o.name for o in self.pending.values())
123
  )
124
  raise Warning(msg)
 
 
125
 
126
  async def get_obj(self, name: str) -> Any:
127
  return await self.uploaded[name].download()
 
23
  obj_id: ObjectId
24
  deletehash: str
25
  cached_obj: Any = None
26
+ link: str
27
 
28
  def __getstate__(self):
29
  d = super().__getstate__()
 
64
  during the same runtime does not need to download the object.
65
  """
66
  async with Imgur() as imgur:
67
+ oid, delete, link = await imgur.upload_image(self.image())
68
 
69
  return UploadedObject(
70
  name=self.name,
71
  obj_id=oid,
72
  deletehash=delete,
73
+ cached_obj=self.obj,
74
+ link=link
75
  )
76
 
77
 
 
124
  '\n'.join(o.name for o in self.pending.values())
125
  )
126
  raise Warning(msg)
127
+
128
+ return uploaded
129
 
130
  async def get_obj(self, name: str) -> Any:
131
  return await self.uploaded[name].download()
db/audio_db/is3/wrapper.py CHANGED
@@ -11,7 +11,6 @@ from typing import Optional, Union, Tuple
11
  from .utils import image_to_b64_string, bytes_to_image
12
 
13
  dotenv.load_dotenv()
14
- AUTH_HEADER = {'Authorization': f"Client-ID {os.getenv('IS3_CLIENT_ID')}"}
15
  API_ENDPOINTS = {
16
  'upload': 'https://api.imgur.com/3/upload/',
17
  'download': 'http://i.imgur.com/',
@@ -26,7 +25,6 @@ async def get_tokens():
26
  r = await session.request(
27
  method='post',
28
  url=API_ENDPOINTS['auth'],
29
- headers=AUTH_HEADER,
30
  data={
31
  'refresh_token': os.getenv("IS3_REFRESH_TOKEN"),
32
  'client_id': os.getenv("IS3_CLIENT_ID"),
@@ -38,6 +36,8 @@ async def get_tokens():
38
  return r['access_token'], r['refresh_token']
39
 
40
  ACCESS_TOKEN, REFRESH_TOKEN = asyncio.run(get_tokens())
 
 
41
 
42
 
43
 
@@ -73,7 +73,7 @@ class ImgurClient:
73
  headers=AUTH_HEADER,
74
  data={'image': data, 'type': 'base64'}
75
  )
76
- return r['id'], r['deletehash']
77
 
78
  async def download_image(self, image_id: str) -> Image.Image:
79
  """Download the image and return the data as bytes."""
@@ -86,12 +86,3 @@ class ImgurClient:
86
  """Delete an image using a deletehash string"""
87
  url = API_ENDPOINTS['delete'] + deletehash
88
  await self._request('delete', url, headers=AUTH_HEADER)
89
-
90
-
91
- async def get_token():
92
- im = ImgurClient()
93
- await im.get_access_token()
94
-
95
- import asyncio
96
- loop = asyncio.get_event_loop()
97
- loop.run_until_complete(get_token())
 
11
  from .utils import image_to_b64_string, bytes_to_image
12
 
13
  dotenv.load_dotenv()
 
14
  API_ENDPOINTS = {
15
  'upload': 'https://api.imgur.com/3/upload/',
16
  'download': 'http://i.imgur.com/',
 
25
  r = await session.request(
26
  method='post',
27
  url=API_ENDPOINTS['auth'],
 
28
  data={
29
  'refresh_token': os.getenv("IS3_REFRESH_TOKEN"),
30
  'client_id': os.getenv("IS3_CLIENT_ID"),
 
36
  return r['access_token'], r['refresh_token']
37
 
38
  ACCESS_TOKEN, REFRESH_TOKEN = asyncio.run(get_tokens())
39
+ # AUTH_HEADER = {'Authorization': f"Client-ID {os.getenv('IS3_CLIENT_ID')}"}
40
+ AUTH_HEADER = {'Authorization': f"Bearer {ACCESS_TOKEN}"}
41
 
42
 
43
 
 
73
  headers=AUTH_HEADER,
74
  data={'image': data, 'type': 'base64'}
75
  )
76
+ return r['id'], r['deletehash'], r['link']
77
 
78
  async def download_image(self, image_id: str) -> Image.Image:
79
  """Download the image and return the data as bytes."""
 
86
  """Delete an image using a deletehash string"""
87
  url = API_ENDPOINTS['delete'] + deletehash
88
  await self._request('delete', url, headers=AUTH_HEADER)
 
 
 
 
 
 
 
 
 
db/generate_audio.py CHANGED
@@ -1,13 +1,17 @@
1
- from sentence_transformers import SentenceTransformer
 
2
  from pymilvus import Collection
3
  import asyncio
4
-
 
5
  from db_connect import connect
6
 
7
  # for audio storage
8
- from audio_db.is3 import is3
9
 
10
  connect()
 
 
11
 
12
  async def generate_audio():
13
  response = Collection("Response")
@@ -15,7 +19,8 @@ async def generate_audio():
15
 
16
  data = []
17
 
18
- response_iterator = response.query_iterator(limit=1, output_fields=['text', 'embeddings'])
 
19
  while True:
20
  res = response_iterator.next()
21
  if len(res) == 0:
@@ -24,22 +29,27 @@ async def generate_audio():
24
  response_iterator.close()
25
  break
26
 
27
- bucket = is3.Bucket(str(res[0]['id']))
28
-
29
  # generate audio
30
- audio_bytes = open('445766006129375465.wav', 'rb').read()
 
 
 
31
 
32
  # store the audio
33
- bucket.stage_obj(audio_bytes, 'audio')
34
- await bucket.commit()
35
 
36
  # save the audio record to AudioResponse
37
- data.append([res[0]['text'], str(res[0]['id']), res[0]['embeddings']])
 
38
 
39
  audio_response.insert(list(zip(*data)))
40
  audio_response.flush()
41
 
 
 
 
42
 
43
  if __name__ == '__main__':
44
- loop = asyncio.get_event_loop()
45
- loop.run_until_complete(generate_audio())
 
1
+ import io
2
+
3
  from pymilvus import Collection
4
  import asyncio
5
+ from bark import SAMPLE_RATE, generate_audio, preload_models
6
+ from scipy.io.wavfile import write
7
  from db_connect import connect
8
 
9
  # for audio storage
10
+ from audio_db.is3.is3 import StagedObject
11
 
12
  connect()
13
+ preload_models()
14
+
15
 
16
  async def generate_audio():
17
  response = Collection("Response")
 
19
 
20
  data = []
21
 
22
+ response_iterator = response.query_iterator(batch_size=1, output_fields=['text', 'embeddings'])
23
+ ids_to_delete = []
24
  while True:
25
  res = response_iterator.next()
26
  if len(res) == 0:
 
29
  response_iterator.close()
30
  break
31
 
 
 
32
  # generate audio
33
+ audio_array = generate_audio(res[0]['text'], history_prompt="en_speaker_3")
34
+ bytes_io = io.BytesIO()
35
+ write(bytes_io, SAMPLE_RATE, audio_array)
36
+ audio_bytes = bytes_io.read()
37
 
38
  # store the audio
39
+ obj = StagedObject(obj=audio_bytes, name='audio')
40
+ uploaded_object = await obj.upload()
41
 
42
  # save the audio record to AudioResponse
43
+ data.append([res[0]['text'], uploaded_object.obj_id, res[0]['embeddings']])
44
+ ids_to_delete.append(res[0]['id'])
45
 
46
  audio_response.insert(list(zip(*data)))
47
  audio_response.flush()
48
 
49
+ # delete text to generate audio
50
+ response.delete(expr=f"id in {str(ids_to_delete)}")
51
+
52
 
53
  if __name__ == '__main__':
54
+ loop = asyncio.new_event_loop()
55
+ loop.run_until_complete(generate_audio())