derek-thomas HF staff committed on
Commit
3772eaf
1 Parent(s): 7d5ff0e

Move client instantiation

Browse files
Files changed (1) hide show
  1. src/utilities.py +5 -3
src/utilities.py CHANGED
@@ -12,7 +12,6 @@ USERNAME = os.environ["USERNAME"]
12
  OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
13
  PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
14
 
15
- client = Client("derek-thomas/nomic-embeddings")
16
  logger = setup_logger(__name__)
17
 
18
 
@@ -29,6 +28,9 @@ async def load_datasets():
29
 
30
 
31
  def merge_and_update_datasets(dataset, original_dataset):
 
 
 
32
  # Merge and figure out which rows need to be updated with embeddings
33
  odf = original_dataset['train'].to_pandas()
34
  df = dataset['train'].to_pandas()
@@ -50,13 +52,13 @@ def merge_and_update_datasets(dataset, original_dataset):
50
  # Iterate over the DataFrame rows where 'embedding' is None
51
  for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
52
  # Update 'embedding' for the current row using our function
53
- merged_df.at[index, 'embedding'] = update_embeddings(row['content'])
54
 
55
  dataset['train'] = Dataset.from_pandas(merged_df)
56
  logger.info(f"Updated {updated_rows} rows")
57
  return dataset
58
 
59
 
60
- def update_embeddings(content):
61
  embedding = client.predict(content, api_name="/embed")
62
  return np.array(embedding)
 
12
  OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
13
  PROCESSED_DATASET = os.environ['PROCESSED_DATASET']
14
 
 
15
  logger = setup_logger(__name__)
16
 
17
 
 
28
 
29
 
30
  def merge_and_update_datasets(dataset, original_dataset):
31
+ # Get client
32
+ client = Client("derek-thomas/nomic-embeddings")
33
+
34
  # Merge and figure out which rows need to be updated with embeddings
35
  odf = original_dataset['train'].to_pandas()
36
  df = dataset['train'].to_pandas()
 
52
  # Iterate over the DataFrame rows where 'embedding' is None
53
  for index, row in merged_df[merged_df['embedding'].isnull()].iterrows():
54
  # Update 'embedding' for the current row using our function
55
+ merged_df.at[index, 'embedding'] = update_embeddings(content=row['content'], client=client)
56
 
57
  dataset['train'] = Dataset.from_pandas(merged_df)
58
  logger.info(f"Updated {updated_rows} rows")
59
  return dataset
60
 
61
 
62
+ def update_embeddings(content, client):
63
  embedding = client.predict(content, api_name="/embed")
64
  return np.array(embedding)