Spaces:
Runtime error
Runtime error
File size: 2,764 Bytes
9ae1b66 3d12d3a 65f2fab 9ae1b66 ef9cbc8 9ae1b66 9930cd7 9ae1b66 9930cd7 9ae1b66 9930cd7 9ae1b66 9930cd7 9ae1b66 3772eaf 65f2fab 3772eaf 9ae1b66 ef9cbc8 9ae1b66 ef9cbc8 9ae1b66 3772eaf 9ae1b66 ef9cbc8 9ae1b66 3772eaf a6deb48 9ae1b66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import os
import numpy as np
import pandas as pd
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client
from src.my_logger import setup_logger
# --- Module configuration (resolved once, at import time) ---
# The three os.environ[...] lookups below use indexing, so importing this
# module raises KeyError if SUBREDDIT, USERNAME or PROCESSED_DATASET is unset.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
# Repo id of the original dataset, built from the user name and subreddit.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Repo id of the processed dataset; load_datasets() downloads it alongside OG_DATASET.
PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
# Identifier handed to gradio_client.Client in merge_and_update_datasets().
embeddings_space = f"{USERNAME}/nomic-embeddings"
# Module-level logger from the project's own setup_logger helper.
logger = setup_logger(__name__)
def load_datasets():
    """Download fresh copies of the processed and the original dataset.

    Both repositories are fetched with ``DownloadMode.FORCE_REDOWNLOAD``.
    The processed dataset is downloaded first, then the original one, and
    each download is bracketed by an info-level log line.

    Returns:
        Tuple ``(dataset, original_dataset)``: the result of ``load_dataset``
        for ``PROCESSED_DATASET`` and for ``OG_DATASET`` respectively.
    """
    fetched = []
    # Order matters: the processed dataset first, the original one second.
    for repo_id in (PROCESSED_DATASET, OG_DATASET):
        logger.info(f"Trying to download {repo_id}")
        fetched.append(
            load_dataset(repo_id, download_mode=DownloadMode.FORCE_REDOWNLOAD)
        )
        logger.info(f"Loaded {repo_id}")

    dataset, original_dataset = fetched
    return dataset, original_dataset
def merge_and_update_datasets(dataset, original_dataset):
    """Bring the processed dataset in line with the original and embed stale rows.

    The 'train' split of ``original_dataset`` is left-joined on 'id' with the
    'id', 'content' and 'embedding' columns of the 'train' split of
    ``dataset``. Wherever the two 'content' values differ, the embedding is
    reset to ``None``. Every row whose embedding is null is then sent to
    ``update_embeddings``. The merged frame, with the original dataset's
    content and without the 'new' and 'updated' columns, replaces
    ``dataset['train']``.

    Args:
        dataset: Processed dataset; its 'train' entry is overwritten in place.
        original_dataset: Original dataset whose 'train' split is the left
            side of the join (it must provide the 'new' and 'updated'
            columns that are dropped here).

    Returns:
        Tuple ``(dataset, updated_row_count)`` where ``updated_row_count`` is
        the number of merged rows whose content differed between the splits.
    """
    # The client is created up front, before any dataframe work.
    client = Client(embeddings_space)

    source_frame = original_dataset['train'].to_pandas()
    processed_frame = dataset['train'].to_pandas()

    # Left join keeps every original row; the original's overlapping columns
    # get the '_odf' suffix while the processed ones keep their plain names.
    processed_subset = processed_frame[['id', 'content', 'embedding']]
    combined = pd.merge(
        source_frame,
        processed_subset,
        on='id',
        how='left',
        suffixes=('_odf', ''),
    )

    # One mask serves both the count and the embedding reset.
    content_changed = combined['content_odf'] != combined['content']
    updated_row_count = int(content_changed.sum())

    # Invalidate embeddings whose source text no longer matches.
    combined['embedding'] = np.where(content_changed, None, combined['embedding'])

    # Keep the original dataset's content under the plain 'content' name.
    combined = combined.drop(columns=['content', 'new', 'updated']).rename(
        columns={'content_odf': 'content'}
    )

    logger.info(f"Updating {updated_row_count} rows...")

    # Re-embed every row that currently has no embedding.
    needs_embedding = combined['embedding'].isnull()
    for row_label, record in combined.loc[needs_embedding].iterrows():
        combined.at[row_label, 'embedding'] = update_embeddings(
            content=record['content'], client=client
        )

    dataset['train'] = Dataset.from_pandas(combined)
    logger.info(f"Updated {updated_row_count} rows")
    return dataset, updated_row_count
def update_embeddings(content, client):
    """Request an embedding for ``content`` from the embeddings endpoint.

    The text is sent with a 'search_document: ' prefix to the client's
    "/embed" API (presumably the task prefix the embedding model expects;
    confirm against the Space).

    Args:
        content: Text to embed; it is concatenated onto the prefix with ``+``.
        client: Object exposing ``predict(text, api_name=...)``.

    Returns:
        The endpoint's response wrapped in a ``numpy.ndarray``.
    """
    prefixed_text = 'search_document: ' + content
    raw_vector = client.predict(prefixed_text, api_name="/embed")
    return np.array(raw_vector)
|