File size: 3,705 Bytes
3e2784f
 
 
 
 
 
 
 
 
 
 
4d0e134
3e2784f
4d0e134
3e2784f
 
 
 
 
6a4b44c
 
4d0e134
6a4b44c
4d0e134
 
3e2784f
 
 
4d0e134
3e2784f
4d0e134
3e2784f
 
 
 
 
 
 
4d0e134
3e2784f
4d0e134
3e2784f
4d0e134
 
 
3e2784f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d0e134
 
 
6a4b44c
3e2784f
4d0e134
 
 
3e2784f
 
 
 
 
 
 
 
 
 
4d0e134
3e2784f
 
 
 
 
 
 
4d0e134
3e2784f
 
 
 
 
4d0e134
3e2784f
 
4d0e134
 
 
 
 
3e2784f
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import asyncio
import logging

import chromadb
import requests
import stamina
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from huggingface_hub import InferenceClient
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import thread_map


from prep_viewer_data import prep_data
from utils import get_chroma_client

# Hub repo id of the model passed to SentenceTransformerEmbeddingFunction, and
# the exact revision (commit hash) it is pinned to.
EMBEDDING_MODEL_NAME = "davanstrien/query-to-dataset-viewer-descriptions"
EMBEDDING_MODEL_REVISION = "07c71d97861a73695f0c53cd6b4b32980007d908"

# Endpoint URL handed to InferenceClient; embeddings are requested from it.
INFERENCE_MODEL_URL = (
    "https://ecg0by60w2vo9j8h.us-east-1.aws.endpoints.huggingface.cloud"
)

# Module logger. Its level is pinned to INFO so that INFO records from this
# module are not dropped by an effective level inherited from the root logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def initialize_clients():
    """Build the two service clients used by the refresh pipeline.

    Returns:
        A ``(chroma_client, inference_client)`` tuple. The first comes from
        ``get_chroma_client()``; the second is an ``InferenceClient`` bound
        to ``INFERENCE_MODEL_URL``.
    """
    logger.info("Initializing clients")
    # Tuple elements evaluate left to right, so the Chroma client is created first.
    return get_chroma_client(), InferenceClient(INFERENCE_MODEL_URL)


def create_collection(chroma_client):
    """Return the ``dataset-viewer-descriptions`` collection, creating it if absent.

    The collection is requested with ``get_or_create=True``, a
    SentenceTransformer embedding function for the pinned model revision, and
    cosine distance for its HNSW index.

    Args:
        chroma_client: Client exposing ``create_collection`` (from
            ``initialize_clients``).

    Returns:
        Whatever ``chroma_client.create_collection`` returns.
    """
    logger.info("Creating or getting collection")
    # trust_remote_code=True executes code from the model repo; pinning the
    # revision keeps that code fixed to a known commit.
    embedder = SentenceTransformerEmbeddingFunction(
        model_name=EMBEDDING_MODEL_NAME,
        trust_remote_code=True,
        revision=EMBEDDING_MODEL_REVISION,
    )
    for message in (
        f"Embedding function: {embedder}",
        f"Embedding model name: {EMBEDDING_MODEL_NAME}",
        f"Embedding model revision: {EMBEDDING_MODEL_REVISION}",
    ):
        logger.info(message)

    index_metadata = {"hnsw:space": "cosine"}
    return chroma_client.create_collection(
        name="dataset-viewer-descriptions",
        get_or_create=True,
        embedding_function=embedder,
        metadata=index_metadata,
    )


@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text, client):
    """Request an embedding for one document from the inference client.

    Only the first 8192 characters of ``text`` are sent. A
    ``requests.HTTPError`` raised by the call is retried by stamina
    (``attempts=3``, ``wait_initial=10``).

    Args:
        text: Document string to embed.
        client: Object exposing ``feature_extraction`` (an ``InferenceClient``
            in this module).

    Returns:
        The value returned by ``client.feature_extraction``.
    """
    max_chars = 8192
    return client.feature_extraction(text[:max_chars])


def embed_and_upsert_datasets(
    dataset_rows_and_ids: list[dict[str, str]],
    collection: chromadb.Collection,
    inference_client: InferenceClient,
    batch_size: int = 100,
):
    """Embed dataset preview prompts in batches and upsert them into Chroma.

    For each row, the document text is ``"HUB_DATASET_PREVIEW: "`` followed by
    ``row["formatted_prompt"]``. The row's ``dataset_id`` is used as the
    Chroma id. Within a batch, documents are embedded concurrently with
    ``thread_map``. Only ids and embeddings are upserted; the document text
    itself is not stored.

    Args:
        dataset_rows_and_ids: Rows carrying at least ``dataset_id`` and
            ``formatted_prompt`` keys.
        collection: Target Chroma collection.
        inference_client: Client passed through to ``embed_card``.
        batch_size: Number of rows embedded and upserted per batch.
    """
    total = len(dataset_rows_and_ids)
    logger.info(f"Embedding and upserting {total} datasets for viewer data")

    def _embed(document):
        return embed_card(document, inference_client)

    for start in tqdm(range(0, total, batch_size)):
        chunk = dataset_rows_and_ids[start : start + batch_size]
        # Read dataset_id then formatted_prompt per row, matching the original key order.
        pairs = [
            (row["dataset_id"], f"HUB_DATASET_PREVIEW: {row['formatted_prompt']}")
            for row in chunk
        ]
        chunk_ids = [row_id for row_id, _ in pairs]
        chunk_docs = [doc for _, doc in pairs]

        vectors = thread_map(_embed, chunk_docs, leave=False)
        logger.info(f"Results: {len(vectors)}")
        # NOTE(review): assumes each feature_extraction result is an array whose
        # first row is the document embedding (hence [0]) — confirm against the
        # endpoint's output shape.
        collection.upsert(
            ids=chunk_ids,
            embeddings=[vector.tolist()[0] for vector in vectors],
        )
        logger.debug(f"Processed batch {start // batch_size + 1}")


async def refresh_viewer_data(sample_size=200_000, min_likes=2):
    """Rebuild the viewer-description embeddings stored in Chroma.

    The steps are:
    1. Create the clients and the collection.
    2. Await ``prep_data``.
    3. Write the prepared frame to ``viewer_data.parquet``.
    4. Embed and upsert every row.

    Args:
        sample_size: Forwarded to ``prep_data``.
        min_likes: Forwarded to ``prep_data``.

    Returns:
        None. If ``prep_data`` returns ``None``, an error is logged and the
        refresh stops before writing the parquet file or upserting anything.
    """
    logger.info(
        f"Refreshing viewer data with sample_size={sample_size} and min_likes={min_likes}"
    )
    chroma_client, inference_client = initialize_clients()
    collection = create_collection(chroma_client)
    logger.info("Collection created successfully")
    logger.info("Preparing data")
    df = await prep_data(sample_size=sample_size, min_likes=min_likes)
    # The None check here is performed before any use of df. Previously
    # write_parquet and to_dicts ran regardless, so a None result crashed with
    # AttributeError.
    if df is None:
        logger.error("Data preparation returned no data; aborting refresh")
        return
    df.write_parquet("viewer_data.parquet")
    logger.info("Data prepared successfully")
    logger.info(f"Data: {df}")

    dataset_rows_and_ids = df.to_dicts()

    logger.info(f"Embedding and upserting {len(dataset_rows_and_ids)} datasets")
    embed_and_upsert_datasets(dataset_rows_and_ids, collection, inference_client)
    logger.info("Refresh completed successfully")


if __name__ == "__main__":
    # Script entry point: give the root logger a handler at INFO, then run the
    # async refresh to completion on a fresh event loop.
    logging.basicConfig(level="INFO")
    asyncio.run(refresh_viewer_data())