charlieoneill committed
Commit d6eab4f
1 Parent(s): 3187d23

feature families
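
This commit builds out feature families and feature splitting in the app: an interactive directed graph for each feature family (via networkx and plotly), a "Parent Co-Occurrence" column in the family table, "Best Match in SAE16"/"Best Match in SAE32" tables in the feature view, downloads for the new nearest-neighbour and norm files, search results sorted by Pearson correlation, and co-occurrence counts normalized by feature norms instead of reported as raw integers.

As a quick reference, the new normalization in `calculate_co_occurrences` restated as a self-contained sketch. Shapes are taken from app.py (`topk_indices` is the (n_docs, k) int32 matrix of active feature indices per abstract; `norms` is a per-feature norm vector); the function name is ours, not part of the app.

```python
import numpy as np
from collections import Counter

def normalized_co_occurrences(topk_indices, norms, target_index, n_features=9216):
    # Documents where the target feature is among the top-k active features
    mask = np.any(topk_indices == target_index, axis=1)
    counts = Counter(topk_indices[mask].flatten())
    counts.pop(target_index, None)  # app.py uses `del`; pop avoids a KeyError when there are no hits
    result = np.zeros(n_features, dtype=np.float32)
    idx = list(counts.keys())
    result[idx] = list(counts.values())
    # Scale each count by the smaller of the two features' norms, so pairs
    # involving rare features are not drowned out by frequent ones.
    result[idx] /= np.minimum(norms[idx], norms[target_index])
    return result
```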

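The new feature-splitting tables are filled by a reverse lookup over the `{subject}_nns_16to64.json` and `{subject}_nns_32to64.json` maps. The layout sketched below is inferred from the loops in `visualize_feature` (each SAE16/SAE32 key maps to a list of entries like `{"feature": [sae64_index, ...], "similarity": 0.93}`), so treat both the layout and the helper name as assumptions:

```python
def reverse_nn_matches(nns, sae64_index):
    """Collect features of the smaller SAE whose recorded match is `sae64_index`."""
    hits = []
    for key, matches in nns.items():
        for match in matches:
            if match['feature'][0] == sae64_index:
                hits.append([key, float(match['similarity'])])
    return hits
```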
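For the family graph, `create_interactive_directed_graph` first prunes weak co-occurrence edges and then keeps only edges running from denser to sparser features before handing the matrix to networkx. A minimal sketch of that preprocessing step (helper name ours):

```python
import numpy as np
import networkx as nx

def family_digraph(matrix, densities, threshold=0.07):
    m = np.array(matrix, dtype=float)
    m[m < threshold] = 0.0  # prune weak co-occurrence edges
    n = len(densities)
    for i in range(n):
        for j in range(n):
            if densities[i] < densities[j]:
                m[i][j] = 0.0  # keep only edges from denser to sparser features
    return nx.from_numpy_array(m, create_using=nx.DiGraph())
```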
Files changed (2)
  1. .gitignore +3 -1
  2. app.py +382 -27
.gitignore CHANGED
@@ -1 +1,3 @@
-data/
+data/
+__pycache__
+__pycache__/
app.py CHANGED
@@ -1,3 +1,132 @@
+# import gradio as gr
+# import numpy as np
+# import json
+# import pandas as pd
+# from openai import OpenAI
+# import yaml
+# from typing import Optional, List, Dict, Tuple, Any
+# from topk_sae import FastAutoencoder
+# import torch
+# import plotly.express as px
+# from collections import Counter
+# from huggingface_hub import hf_hub_download
+# import os
+# import networkx as nx
+# import plotly.graph_objs as go
+# from ast import literal_eval as make_tuple
+# import random
+
+# import os
+# print(os.getenv('MODEL_REPO_ID'))
+
+# # Constants
+# EMBEDDING_MODEL = "text-embedding-3-small"
+# d_model = 1536
+# n_dirs = d_model * 6
+# k = 64
+# auxk = 128
+# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+# torch.set_grad_enabled(False)
+
+# # Function to download all necessary files
+# def download_all_files():
+#     files_to_download = [
+#         "astroPH_paper_metadata.csv",
+#         "csLG_feature_analysis_results_64.json",
+#         "astroPH_topk_indices_64_9216_int32.npy",
+#         "astroPH_64_9216.pth",
+#         "astroPH_topk_values_64_9216_float16.npy",
+#         "csLG_abstract_texts.json",
+#         "csLG_topk_values_64_9216_float16.npy",
+#         "csLG_abstract_embeddings_float16.npy",
+#         "csLG_paper_metadata.csv",
+#         "csLG_64_9216.pth",
+#         "astroPH_abstract_texts.json",
+#         "astroPH_feature_analysis_results_64.json",
+#         "csLG_topk_indices_64_9216_int32.npy",
+#         "astroPH_abstract_embeddings_float16.npy",
+#         # "csLG_clean_families_64_9216.json",
+#         # "astroPH_clean_families_64_9216.json",
+#         # "astroPH_family_analysis_64_9216.json",
+#         "csLG_family_analysis_64_9216.json"
+#     ]
+
+#     for file in files_to_download:
+#         local_path = os.path.join("data", file)
+#         os.makedirs(os.path.dirname(local_path), exist_ok=True)
+#         hf_hub_download(repo_id="charlieoneill/saerch-ai-data", filename=file, local_dir="data")
+#         print(f"Downloaded {file}")
+
+# # Load configuration and initialize OpenAI client
+# download_all_files()
+
+# # Load the API key from the environment variable
+# api_key = os.getenv('openai_key')
+
+# # Ensure the API key is set
+# if not api_key:
+#     raise ValueError("The environment variable 'openai_key' is not set.")
+
+# # Initialize the OpenAI client with the API key
+# client = OpenAI(api_key=api_key)
+
+# # Function to load data for a specific subject
+# def load_subject_data(subject):
+
+#     embeddings_path = f"data/{subject}_abstract_embeddings_float16.npy"
+#     texts_path = f"data/{subject}_abstract_texts.json"
+#     feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
+#     metadata_path = f'data/{subject}_paper_metadata.csv'
+#     topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
+#     norms_path = f"data/{subject}_norms_{k}_{n_dirs}.npy"
+#     topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
+#     families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
+#     family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
+#     nns_32to64 = json.load(open(f"data/{subject}_nns_32to64.json"))
+#     nns_16to32 = json.load(open(f"data/{subject}_nns_16to32.json"))
+#     nns_16to64 = json.load(open(f"data/{subject}_nns_16to64.json"))
+
+#     abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
+#     with open(texts_path, 'r') as f:
+#         abstract_texts = json.load(f)
+#     with open(feature_analysis_path, 'r') as f:
+#         feature_analysis = json.load(f)
+#     df_metadata = pd.read_csv(metadata_path)
+#     topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
+#     topk_values = np.load(topk_values_path).astype(np.float32)
+#     norms = np.load(norms_path).astype(np.float32)
+
+#     model_filename = f"{subject}_64_9216.pth"
+#     model_path = os.path.join("data", model_filename)
+
+#     ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik=0).to(device)
+#     ae.load_state_dict(torch.load(model_path))
+#     ae.eval()
+
+#     weights = torch.load(model_path)
+#     decoder = weights['decoder.weight'].cpu().numpy()
+#     del weights
+
+#     with open(family_analysis_path, 'r') as f:
+#         family_analysis = json.load(f)
+
+
+#     return {
+#         'abstract_embeddings': abstract_embeddings,
+#         'abstract_texts': abstract_texts,
+#         'feature_analysis': feature_analysis,
+#         'df_metadata': df_metadata,
+#         'topk_indices': topk_indices,
+#         'topk_values': topk_values,
+#         'norms': norms,
+#         'nns_32to64': nns_32to64,
+#         'nns_16to64': nns_16to64,
+#         'ae': ae,
+#         'decoder': decoder,
+#         # 'feature_families': feature_families,
+#         'family_analysis': family_analysis
+#     }
+
 import gradio as gr
 import numpy as np
 import json
@@ -11,6 +140,10 @@ import plotly.express as px
 from collections import Counter
 from huggingface_hub import hf_hub_download
 import os
+import networkx as nx
+import plotly.graph_objs as go
+from ast import literal_eval as make_tuple
+import random
 
 import os
 print(os.getenv('MODEL_REPO_ID'))
@@ -44,7 +177,15 @@ def download_all_files():
         # "csLG_clean_families_64_9216.json",
         # "astroPH_clean_families_64_9216.json",
         "astroPH_family_analysis_64_9216.json",
-        "csLG_family_analysis_64_9216.json"
+        "csLG_family_analysis_64_9216.json",
+        "csLG_nns_32to64.json",
+        "csLG_nns_16to32.json",
+        "csLG_nns_16to64.json",
+        "astroPH_nns_32to64.json",
+        "astroPH_nns_16to32.json",
+        "astroPH_nns_16to64.json",
+        "csLG_norms_64_9216_float16.npy",
+        "astroPH_norms_64_9216_float16.npy"
     ]
 
     for file in files_to_download:
@@ -74,9 +215,13 @@ def load_subject_data(subject):
     feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
     metadata_path = f'data/{subject}_paper_metadata.csv'
     topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
+    norms_path = f"data/{subject}_norms_{k}_{n_dirs}_float16.npy"
     topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
     families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
     family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
+    nns_32to64 = json.load(open(f"data/{subject}_nns_32to64.json"))
+    nns_16to32 = json.load(open(f"data/{subject}_nns_16to32.json"))
+    nns_16to64 = json.load(open(f"data/{subject}_nns_16to64.json"))
 
     abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
     with open(texts_path, 'r') as f:
@@ -86,6 +231,7 @@ def load_subject_data(subject):
     df_metadata = pd.read_csv(metadata_path)
     topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
     topk_values = np.load(topk_values_path).astype(np.float32)
+    norms = np.load(norms_path).astype(np.float32)
 
     model_filename = f"{subject}_64_9216.pth"
     model_path = os.path.join("data", model_filename)
@@ -109,6 +255,9 @@ def load_subject_data(subject):
         'df_metadata': df_metadata,
         'topk_indices': topk_indices,
         'topk_values': topk_values,
+        'norms': norms,
+        'nns_32to64': nns_32to64,
+        'nns_16to64': nns_16to64,
         'ae': ae,
         'decoder': decoder,
         # 'feature_families': feature_families,
@@ -163,13 +312,15 @@ def get_feature_activations(subject, feature_index, m=5, min_length=100):
 
 def calculate_co_occurrences(subject, target_index, n_features=9216):
     topk_indices = subject_data[subject]['topk_indices']
+    norms = subject_data[subject]['norms']
 
     mask = np.any(topk_indices == target_index, axis=1)
     co_occurring_indices = topk_indices[mask].flatten()
     co_occurrences = Counter(co_occurring_indices)
     del co_occurrences[target_index]
-    result = np.zeros(n_features, dtype=int)
+    result = np.zeros(n_features, dtype=np.float32)
     result[list(co_occurrences.keys())] = list(co_occurrences.values())
+    result[list(co_occurrences.keys())] /= np.minimum(norms[list(co_occurrences.keys())], norms[target_index])
     return result
 
 def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
@@ -291,10 +442,175 @@ def visualize_feature(subject, index):
         "Co-occurrences": topk_values_co_occurrence
     })
     df_co_occurrences_styled = df_co_occurrences.style.format({
-        "Co-occurrences": "{:.0f}" # Keep as integer
+        "Co-occurrences": "{:.2f}" # 2 decimal points
     })
 
-    return output, styled_top_abstracts, df_top_correlated_styled, df_bottom_correlated_styled, df_co_occurrences_styled, fig2
+    # Add new code for feature splitting
+    nns_16to64 = subject_data[subject]['nns_16to64']
+    nns_32to64 = subject_data[subject]['nns_32to64']
+
+    # Get nearest neighbors for 16 and 32
+    #nn_16 = nns_16to64[str(index)]
+
+    # this is really involved it's a lot easier the other direction
+    nn_16 = []
+    for key in nns_16to64.keys():
+        for match in nns_16to64[key]:
+            if index == match['feature'][0]:
+                nn_16.append([key, float(match['similarity'])])
+
+    #nn_32 = nns_32to64[str(index)]
+    nn_32 = []
+    for key in nns_32to64.keys():
+        for match in nns_32to64[key]:
+            if index == match['feature'][0]:
+                nn_32.append([key, float(match['similarity'])])
+
+    # Create dataframes for 16 and 32 nearest neighbors
+    try:
+        df_16 = pd.DataFrame(nn_16, columns=["Feature", "Cosine Similarity"])
+        df_16 = df_16.style.format({"Cosine Similarity": "{:.4f}"})
+    except:
+        df_16 = pd.DataFrame(["No Match"], columns=["Feature"])
+
+    try:
+        df_32 = pd.DataFrame(nn_32, columns=["Feature", "Cosine Similarity"])
+        df_32 = df_32.style.format({"Cosine Similarity": "{:.4f}"})
+    except:
+        df_32 = pd.DataFrame(["No Match"], columns=["Feature"])
+
+    return output, styled_top_abstracts, df_top_correlated_styled, df_bottom_correlated_styled, df_co_occurrences_styled, fig2, df_16, df_32
+
+def create_interactive_directed_graph(family):
+    matrix = np.array(family['matrix'])
+    matrix[matrix < 0.07] = 0
+    densities = family['densities']
+    for i in range(len(densities)):
+        for j in range(len(densities)):
+            if densities[i] < densities[j]:
+                matrix[i][j] = 0
+
+    G = nx.from_numpy_array(matrix, create_using=nx.DiGraph())
+
+    num_nodes = len(family['feature_f1'])
+    all_f1s = family['feature_pearson'] + [family['family_pearson']]
+    node_info = {i: {"name": f"{family['feature_names'][i]}", "density": family['densities'][i], "pearson": all_f1s[i]} for i in range(num_nodes)}
+    nx.set_node_attributes(G, node_info)
+
+    # Create node trace
+    node_x = []
+    node_y = []
+    node_text = []
+    node_size = []
+    node_color = []
+    pos = nx.spring_layout(G, k = np.sqrt(1/num_nodes) * 3)
+    for node in G.nodes():
+        x, y = pos[node]
+        node_x.append(x)
+        node_y.append(y)
+        node_text.append(G.nodes[node]['name'] + "<br>log density: " + str(round(np.log10(G.nodes[node]['density'] + 1e-5), 3)))
+        node_size.append((np.log10(G.nodes[node]['density'] + 1e-5) + 6) * 10)
+        node_color.append(G.nodes[node]['pearson'])
+
+    node_trace = go.Scatter(
+        x=node_x, y=node_y,
+        mode='markers',
+        hoverinfo='text',
+        marker=dict(
+            showscale=True,
+            colorscale='purples',
+            size=node_size, # Set node marker size to node['f1']
+            color=node_color,
+            cmin = 0,
+            cmax = 1,
+            colorbar=dict(
+                thickness=15,
+                title='Pearson Correlation',
+                xanchor='left',
+                titleside='right',
+            ),
+            line_width=2,
+            opacity = 1,),
+        opacity = 1)
+
+    node_trace.text = node_text
+
+    # Create edge trace
+    edge_traces = []
+    annotations = []
+    for edge in G.edges():
+        x0, y0 = pos[edge[0]]
+        x1, y1 = pos[edge[1]]
+        weight = matrix[edge[0], edge[1]]
+
+        # Calculate offset (adjust this value to move arrows further from or closer to nodes)
+        offset = 0.00
+        start_x = x0
+        start_y = y0
+        end_x = x1
+        end_y = y1
+
+        # # Calculate new start and end points
+        # if start_x > end_x:
+        #     start_x = x0 - offset
+        #     end_x = x0 + offset
+        # else:
+        #     start_x = x0 + offset
+        #     end_x = x1 - offset
+        # if start_y > end_y:
+        #     start_y = y0 - offset
+        #     end_y = y1 + offset
+        # else:
+        #     start_y = y0 + offset
+        #     end_y = y1 - offset
+
+        edge_trace = go.Scatter(
+            x=[start_x, end_x, None],
+            y=[start_y, end_y, None],
+            line=dict(width=weight * 20, color='#888'), # Multiply weight by 20 for better visibility
+            hovertext="weight: " + str(round(weight, 3)), # Set the hover text to the edge weight
+            mode='lines',
+            line_shape='spline',
+            opacity = 0.5,
+        )
+        edge_traces.append(edge_trace)
+
+        annotation = dict(
+            ax=start_x,
+            ay=start_y,
+            x=end_x,
+            y=end_y,
+            xref='x',
+            yref='y',
+            axref='x',
+            ayref='y',
+            showarrow=True,
+            arrowhead=4,
+            arrowsize=4, #max(min(weight * 3, 0.3), 2), # Reduced from 30 to 10
+            arrowwidth=1, # Reduced from 30 to 2
+            arrowcolor='#999',
+            opacity = 1,
+        )
+        annotations.append(annotation)
+
+    annotation_trace = go.Scatter(x=[], y=[], mode='markers', hoverinfo='none', marker=dict(opacity=0))
+
+    # Create the figure
+    fig = go.Figure(data=[annotation_trace, *edge_traces, node_trace],
+                    layout=go.Layout(
+                        showlegend=False,
+                        hovermode='closest',
+                        margin=dict(b=20,l=5,r=5,t=40),
+                        xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
+                        yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)),
+                    )
+    fig.update_xaxes(showline=False, linewidth=0, gridcolor='white')
+    fig.update_yaxes(showline=False, linewidth=0, gridcolor='white')
+    fig.update_layout(
+        plot_bgcolor='white',
+        annotations=annotations,
+    )
+    return fig
 
 # Modify the main interface function
 def create_interface():
@@ -453,7 +769,10 @@ def create_interface():
         def search_feature_labels(search_text):
             if not search_text:
                 return gr.CheckboxGroup(choices=[])
-            matches = [f"{f['label']} ({f['index']})" for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
+            matches = [f for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
+            matches = sorted(matches, key=lambda x: x['pearson_correlation'], reverse=True)
+            matches = [f"{f['label']} ({f['index']})" for f in matches]
+
             return gr.CheckboxGroup(choices=matches[:10])
 
         feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])
@@ -536,24 +855,24 @@ def create_interface():
                         wrap=True
                    )
 
-            gr.Markdown("## Correlated Features")
+            gr.Markdown("## Feature Splitting")
             with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown("### Top 5 Correlated Features")
-                    top_correlated = gr.Dataframe(
-                        headers=["Feature", "Cosine similarity"],
+                    gr.Markdown("### Best Match in SAE16")
+                    nn_16_table = gr.Dataframe(
+                        headers=["Feature", "Cosine Similarity"],
                         interactive=False
                    )
                 with gr.Column(scale=1):
-                    gr.Markdown("### Bottom 5 Correlated Features")
-                    bottom_correlated = gr.Dataframe(
-                        headers=["Feature", "Cosine similarity"],
+                    gr.Markdown("### Best Match in SAE32")
+                    nn_32_table = gr.Dataframe(
+                        headers=["Feature", "Cosine Similarity"],
                         interactive=False
                    )
-
+
             with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown("## Top 5 Co-occurring Features")
+                    gr.Markdown("## Top Co-occurring Features")
                     co_occurring_features = gr.Dataframe(
                         headers=["Feature", "Co-occurrences"],
                         interactive=False
@@ -562,10 +881,31 @@ def create_interface():
                    gr.Markdown(f"## Activation Value Distribution")
                    activation_dist = gr.Plot()
 
+            gr.Markdown("## Correlated Features")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    gr.Markdown("### Top Correlated Features")
+                    top_correlated = gr.Dataframe(
+                        headers=["Feature", "Cosine similarity"],
+                        interactive=False
+                    )
+                with gr.Column(scale=1):
+                    gr.Markdown("### Bottom Correlated Features")
+                    bottom_correlated = gr.Dataframe(
+                        headers=["Feature", "Cosine similarity"],
+                        interactive=False
+                    )
+
+
+
+
         def search_feature_labels(search_text, current_subject):
             if not search_text:
                 return gr.CheckboxGroup(choices=[])
-            matches = [f"{f['label']} ({f['index']})" for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
+            matches = [f for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
+            matches = sorted(matches, key=lambda x: x['pearson_correlation'], reverse=True)
+            matches = [f"{f['label']} ({f['index']})" for f in matches]
+
             return gr.CheckboxGroup(choices=matches[:10])
 
         feature_search.change(search_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])
@@ -576,15 +916,15 @@ def create_interface():
 
            # Extract the feature index from the selected feature string
            feature_index = int(selected_features[0].split('(')[-1].strip(')'))
-           feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist = visualize_feature(current_subject, feature_index)
+           feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, nn_16, nn_32 = visualize_feature(current_subject, feature_index)
 
            # Return the visualization results along with empty values for search box and checkbox
-           return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", []
+           return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", [], nn_16, nn_32
 
        visualize_button.click(
            on_visualize,
            inputs=[feature_matches, subject],
-           outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches]
+           outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches, nn_16_table, nn_32_table]
        )
 
        with gr.Tab("Feature Families"):
@@ -595,19 +935,26 @@ def create_interface():
            family_matches = gr.CheckboxGroup(label="Matching Feature Families", choices=[])
            visualize_family_button = gr.Button("Visualize Feature Family")
 
-           family_info = gr.Markdown()
+
            family_dataframe = gr.Dataframe(
-               headers=["Feature", "F1 Score", "Pearson Correlation"],
-               datatype=["markdown", "number", "number"],
+               headers=["Feature", "Parent Co-Occurrence", "F1 Score", "Pearson"],
+               datatype=["markdown", "number", "number", "number"],
                label="Family and Child Features"
            )
 
+           gr.Markdown("# Family Graph")
+           graph_plot = gr.Plot(label="Directed Graph")
+
+           # family_info = gr.Markdown()
 
            def search_feature_families(search_text, current_subject):
                family_analysis = subject_data[current_subject]['family_analysis']
                if not search_text:
                    return gr.CheckboxGroup(choices=[])
-               matches = [family['superfeature'] for family in family_analysis if search_text.lower() in family['superfeature'].lower()]
+               matches = [family for family in family_analysis if search_text.lower() in family['superfeature'].lower()]
+               matches = sorted(matches, key=lambda x: x['family_pearson'], reverse=True)
+               matches = [family['superfeature'] for family in matches]
+               matches = list(dict.fromkeys(matches))
                return gr.CheckboxGroup(choices=matches[:10]) # Limit to top 10 matches
 
            def visualize_feature_family(selected_families, current_subject):
627
  df_data = [
628
  {
629
  "Feature": f"## {family_data['superfeature']}",
 
630
  "F1 Score": round(family_data['family_f1'], 2),
631
- "Pearson Correlation": round(family_data['family_pearson'], 4)
632
  },
633
  ]
634
 
635
- for name, f1, pearson in zip(family_data['feature_names'], family_data['feature_f1'], family_data['feature_pearson']):
 
 
636
  df_data.append({
637
  "Feature": name,
 
638
  "F1 Score": round(f1, 2),
639
- "Pearson Correlation": round(pearson, 4)
640
  })
641
 
642
  df = pd.DataFrame(df_data)
@@ -645,13 +996,17 @@ def create_interface():
                output += "## Super Reasoning\n"
                output += f"{family_data['super_reasoning']}\n\n"
 
-               return output, df, "", [] # Return empty string for search box and empty list for checkbox
+               graph = create_interactive_directed_graph(family_data)
+
+               #return output, df, "", [], graph # Return empty string for search box and empty list for checkbox
+               return df, "", [], graph
 
            family_search.change(search_feature_families, inputs=[family_search, subject], outputs=[family_matches])
            visualize_family_button.click(
                visualize_feature_family,
                inputs=[family_matches, subject],
-               outputs=[family_info, family_dataframe, family_search, family_matches]
+               #outputs=[family_info, family_dataframe, family_search, family_matches, graph_plot]
+               outputs=[family_dataframe, family_search, family_matches, graph_plot]
            )
 
 