File size: 2,764 Bytes
9ae1b66
 
 
 
 
 
 
 
 
 
 
3d12d3a
 
65f2fab
9ae1b66
 
 
 
ef9cbc8
9ae1b66
9930cd7
9ae1b66
9930cd7
9ae1b66
9930cd7
9ae1b66
9930cd7
9ae1b66
 
 
 
3772eaf
65f2fab
3772eaf
9ae1b66
 
 
 
 
 
 
ef9cbc8
9ae1b66
 
 
 
 
 
 
 
 
ef9cbc8
9ae1b66
 
 
3772eaf
9ae1b66
 
ef9cbc8
 
9ae1b66
 
3772eaf
a6deb48
9ae1b66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os

import numpy as np
import pandas as pd
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client

from src.my_logger import setup_logger

# --- Module configuration (read once at import time) ---
# All of these raise KeyError at import if the env var is missing, which is
# intentional: the script cannot do anything useful without them.
SUBREDDIT = os.environ["SUBREDDIT"]  # subreddit slug used to derive the dataset repo name
USERNAME = os.environ["USERNAME"]    # HF Hub namespace that owns the datasets and the Space
# Source dataset: raw reddit rows, named by convention from USERNAME + SUBREDDIT.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Target dataset: same rows enriched with embeddings; full repo id comes from the env.
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
# Gradio Space that serves the nomic embedding model (see update_embeddings).
embeddings_space = f"{USERNAME}/nomic-embeddings"

logger = setup_logger(__name__)


def load_datasets():
    """Fetch fresh copies of both datasets from the Hub.

    Both downloads use FORCE_REDOWNLOAD so a stale local cache can never
    mask new rows pushed upstream.

    Returns:
        tuple: (processed_dataset, original_dataset) as returned by
        ``datasets.load_dataset``.
    """
    logger.info(f"Trying to download {PROCESSED_DATASET}")
    processed = load_dataset(PROCESSED_DATASET, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.info(f"Loaded {PROCESSED_DATASET}")

    logger.info(f"Trying to download {OG_DATASET}")
    original = load_dataset(OG_DATASET, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.info(f"Loaded {OG_DATASET}")
    return processed, original


def merge_and_update_datasets(dataset, original_dataset):
    """Reconcile the processed dataset with the original one, re-embedding any
    row whose 'content' changed (or is new), and return the updated dataset
    plus the number of rows that needed new embeddings.

    Args:
        dataset: processed dataset (has 'id', 'content', 'embedding' columns).
        original_dataset: source dataset (has 'id', 'content', and — judging by
            the drop() below — 'new' and 'updated' columns).

    Returns:
        tuple: (dataset with refreshed 'train' split, updated_row_count).
    """
    # Get client
    client = Client(embeddings_space)

    # Merge and figure out which rows need to be updated with embeddings
    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Step 1: Merge df onto odf
    # We'll bring in 'content' and 'embedding' from df to compare and possibly update 'embedding'
    # Left merge keeps every original row; suffixes put odf's content in
    # 'content_odf' and df's in 'content'. Rows with no match in df get NaN
    # for 'content'/'embedding'.
    merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))
    # NOTE(review): NaN != NaN is True, so brand-new rows (unmatched in df)
    # are counted here too — consistent with the None assignment below.
    updated_row_count = len(merged_df[merged_df.content != merged_df.content_odf])

    # Step 2: Compare 'content' from odf and df, update 'embedding' if they differ
    # Stale/new rows get embedding=None so they are picked up by isnull() below.
    merged_df['embedding'] = np.where(merged_df['content_odf'] != merged_df['content'], None, merged_df['embedding'])

    # Step 3: Cleanup - keep only the necessary columns.
    # Assuming you want to keep 'content' from 'odf' and the updated 'embedding', and drop the rest
    merged_df = merged_df.drop(columns=['content', 'new', 'updated'])  # Update columns to match df
    merged_df.rename(columns={'content_odf': 'content'}, inplace=True)  # Rename 'content_odf' back to 'content'

    logger.info(f"Updating {updated_row_count} rows...")
    # Iterate over the DataFrame rows where 'embedding' is None
    # (kept embeddings are ndarrays, which isnull() treats as non-null).
    for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
        # Update 'embedding' for the current row using our function
        # One remote call per row — presumably acceptable volume; verify if
        # subreddits grow large.
        merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count


def update_embeddings(content, client):
    """Embed ``content`` via the remote embeddings Space.

    Prepends the nomic 'search_document: ' task prefix before calling the
    Space's /embed endpoint.

    Args:
        content: text to embed.
        client: gradio_client.Client connected to the embeddings Space.

    Returns:
        numpy.ndarray holding the embedding vector.
    """
    prefixed = f"search_document: {content}"
    raw_embedding = client.predict(prefixed, api_name="/embed")
    return np.array(raw_embedding)