text-to-speech / db /query_db.py
Daryl Fung
finalize mvp
8d83939
raw
history blame
1.08 kB
from sentence_transformers import SentenceTransformer
# Milvus collection handle used for the search and query calls below
from pymilvus import Collection
import random
from .audio_db.is3.is3 import UploadedObject
from .db_connect import connect, disconnect
# Executed at import time, before query() is defined.
# NOTE(review): presumably opens the Milvus connection that Collection
# relies on -- confirm in db_connect.
connect()
async def query(embeddings, threshold: float = 0.8):
    """Return the stored audio for a response similar to ``embeddings``.

    Searches the ``AudioResponse`` collection for the 5 nearest entries,
    keeps the hits whose distance value is strictly greater than
    ``threshold``, picks one of them at random, and downloads the object
    named by that record's ``filename`` field.

    Args:
        embeddings: Query vector(s) passed as ``data`` to
            ``Collection.search``. Only the first result set is used.
        threshold: Value a hit's distance must exceed (exclusive) to be
            considered a candidate.

    Returns:
        Whatever ``UploadedObject.download()`` yields for the chosen record
        (presumably the raw audio bytes -- confirm against ``is3``), or
        ``None`` when no hit clears the threshold or the chosen record
        cannot be looked up.
    """
    audio_response = Collection("AudioResponse")

    # NOTE(review): "index_type" and "nlist" look like index-build settings
    # rather than search-time ones -- confirm against the pymilvus search
    # documentation. Left unchanged here.
    search_params = {
        "metric_type": "COSINE",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 1024},
    }
    hits = audio_response.search(
        data=embeddings,
        limit=5,
        anns_field="embeddings",
        param=search_params,
    )[0]

    candidate_ids = [
        hit_id
        for hit_id, distance in zip(hits.ids, hits.distances)
        if distance > threshold
    ]
    if not candidate_ids:
        return None

    # Pick randomly among the qualifying hits rather than always the best one.
    selected_id = random.choice(candidate_ids)
    records = audio_response.query(
        f'id == {selected_id}', output_fields=['text', 'filename']
    )
    if not records:
        # Guard the lookup: an empty result would otherwise raise IndexError.
        return None

    audio_id = records[0]['filename']
    return await UploadedObject(obj_id=audio_id).download()
# NOTE(review): indentation is lost in this view. If this call sits at module
# level it runs at import time, right after connect() and before any caller
# can await query() -- confirm that is intended.
disconnect()