import os

import numpy as np
import pandas as pd
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client

from src.my_logger import setup_logger

SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
embeddings_space = f"{USERNAME}/nomic-embeddings"

logger = setup_logger(__name__)


def load_datasets():
    """Download the latest processed and original datasets.

    Both datasets are loaded with ``DownloadMode.FORCE_REDOWNLOAD`` so the
    local cache is not reused and the newest revision is fetched.

    Returns:
        A ``(dataset, original_dataset)`` tuple: ``dataset`` is loaded from
        ``PROCESSED_DATASET`` and ``original_dataset`` from ``OG_DATASET``,
        each as returned by ``datasets.load_dataset``.
    """
    logger.info(f"Trying to download {PROCESSED_DATASET}")
    dataset = load_dataset(PROCESSED_DATASET, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.info(f"Loaded {PROCESSED_DATASET}")

    logger.info(f"Trying to download {OG_DATASET}")
    original_dataset = load_dataset(OG_DATASET, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.info(f"Loaded {OG_DATASET}")
    return dataset, original_dataset


def merge_and_update_datasets(dataset, original_dataset):
    """Sync the processed dataset with the original one, re-embedding stale rows.

    Every row of ``original_dataset['train']`` is left-joined on ``id`` with
    the ``content``/``embedding`` columns of ``dataset['train']``. A row's
    embedding is cleared when its content differs from the processed copy
    (this includes rows with no match in the processed dataset), and every
    row left without an embedding is embedded again via the embeddings Space.

    Note: ``dataset`` is modified in place (its ``'train'`` split is
    replaced) and is also returned.

    Args:
        dataset: Processed dataset dict; its ``'train'`` split must contain
            ``id``, ``content`` and ``embedding`` columns.
        original_dataset: Source dataset dict; its ``'train'`` split must
            contain ``id``, ``content``, ``new`` and ``updated`` columns
            (``new``/``updated`` are dropped from the result).

    Returns:
        ``(dataset, updated_row_count)``, where ``updated_row_count`` is the
        number of rows that were (re-)embedded.
    """
    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Step 1: left-join so every original row survives. The processed columns
    # keep their names; the original 'content' becomes 'content_odf'.
    merged_df = pd.merge(
        odf,
        df[['id', 'content', 'embedding']],
        on='id',
        how='left',
        suffixes=('_odf', ''),
    )

    # Step 2: clear the embedding wherever the content changed. Rows with no
    # processed match get NaN content, and NaN != anything, so they are cleared too.
    content_changed = merged_df['content_odf'] != merged_df['content']
    merged_df['embedding'] = np.where(content_changed, None, merged_df['embedding'])

    # Step 3: keep the original content plus the (possibly cleared) embedding.
    merged_df = merged_df.drop(columns=['content', 'new', 'updated'])
    merged_df = merged_df.rename(columns={'content_odf': 'content'})

    # Derive the count from the same mask that drives the embedding loop, so the
    # logs and the returned count match the work actually performed (this also
    # covers rows whose embedding was already missing in the processed dataset).
    needs_embedding = merged_df['embedding'].isnull()
    updated_row_count = int(needs_embedding.sum())
    logger.info(f"Updating {updated_row_count} rows...")

    if updated_row_count:
        # Only connect to the embeddings Space when there is something to embed.
        client = Client(embeddings_space)
        # .loc with a boolean mask returns a copy, so writing into merged_df
        # while iterating it is safe.
        for index, content in merged_df.loc[needs_embedding, 'content'].items():
            merged_df.at[index, 'embedding'] = update_embeddings(content=content, client=client)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count


def update_embeddings(content, client):
    """Embed ``content`` through the Space's ``/embed`` endpoint.

    Args:
        content: Text to embed; it is concatenated with a str prefix, so a
            non-str value (e.g. NaN) raises ``TypeError``.
        client: ``gradio_client.Client`` connected to the embeddings Space.

    Returns:
        The embedding returned by the endpoint, converted to a NumPy array.
    """
    # NOTE(review): 'search_document: ' is presumably the document task prefix
    # the nomic embedding model expects for indexed text — confirm against the Space.
    embedding = client.predict('search_document: ' + content, api_name="/embed")
    return np.array(embedding)