Spaces:
Running
Running
charlieoneill
commited on
Commit
•
d6eab4f
1
Parent(s):
3187d23
feature families
Browse files- .gitignore +3 -1
- app.py +382 -27
.gitignore
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
data/
|
|
|
|
|
|
1 |
+
data/
|
2 |
+
__pycache__
|
3 |
+
__pycache__/
|
app.py
CHANGED
@@ -1,3 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import json
|
@@ -11,6 +140,10 @@ import plotly.express as px
|
|
11 |
from collections import Counter
|
12 |
from huggingface_hub import hf_hub_download
|
13 |
import os
|
|
|
|
|
|
|
|
|
14 |
|
15 |
import os
|
16 |
print(os.getenv('MODEL_REPO_ID'))
|
@@ -44,7 +177,15 @@ def download_all_files():
|
|
44 |
# "csLG_clean_families_64_9216.json",
|
45 |
# "astroPH_clean_families_64_9216.json",
|
46 |
"astroPH_family_analysis_64_9216.json",
|
47 |
-
"csLG_family_analysis_64_9216.json"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
]
|
49 |
|
50 |
for file in files_to_download:
|
@@ -74,9 +215,13 @@ def load_subject_data(subject):
|
|
74 |
feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
|
75 |
metadata_path = f'data/{subject}_paper_metadata.csv'
|
76 |
topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
|
|
|
77 |
topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
|
78 |
families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
|
79 |
family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
|
|
|
|
|
|
|
80 |
|
81 |
abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
|
82 |
with open(texts_path, 'r') as f:
|
@@ -86,6 +231,7 @@ def load_subject_data(subject):
|
|
86 |
df_metadata = pd.read_csv(metadata_path)
|
87 |
topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
|
88 |
topk_values = np.load(topk_values_path).astype(np.float32)
|
|
|
89 |
|
90 |
model_filename = f"{subject}_64_9216.pth"
|
91 |
model_path = os.path.join("data", model_filename)
|
@@ -109,6 +255,9 @@ def load_subject_data(subject):
|
|
109 |
'df_metadata': df_metadata,
|
110 |
'topk_indices': topk_indices,
|
111 |
'topk_values': topk_values,
|
|
|
|
|
|
|
112 |
'ae': ae,
|
113 |
'decoder': decoder,
|
114 |
# 'feature_families': feature_families,
|
@@ -163,13 +312,15 @@ def get_feature_activations(subject, feature_index, m=5, min_length=100):
|
|
163 |
|
164 |
def calculate_co_occurrences(subject, target_index, n_features=9216):
|
165 |
topk_indices = subject_data[subject]['topk_indices']
|
|
|
166 |
|
167 |
mask = np.any(topk_indices == target_index, axis=1)
|
168 |
co_occurring_indices = topk_indices[mask].flatten()
|
169 |
co_occurrences = Counter(co_occurring_indices)
|
170 |
del co_occurrences[target_index]
|
171 |
-
result = np.zeros(n_features, dtype=
|
172 |
result[list(co_occurrences.keys())] = list(co_occurrences.values())
|
|
|
173 |
return result
|
174 |
|
175 |
def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
|
@@ -291,10 +442,175 @@ def visualize_feature(subject, index):
|
|
291 |
"Co-occurrences": topk_values_co_occurrence
|
292 |
})
|
293 |
df_co_occurrences_styled = df_co_occurrences.style.format({
|
294 |
-
"Co-occurrences": "{:.
|
295 |
})
|
296 |
|
297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
|
299 |
# Modify the main interface function
|
300 |
def create_interface():
|
@@ -453,7 +769,10 @@ def create_interface():
|
|
453 |
def search_feature_labels(search_text):
|
454 |
if not search_text:
|
455 |
return gr.CheckboxGroup(choices=[])
|
456 |
-
matches = [f
|
|
|
|
|
|
|
457 |
return gr.CheckboxGroup(choices=matches[:10])
|
458 |
|
459 |
feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])
|
@@ -536,24 +855,24 @@ def create_interface():
|
|
536 |
wrap=True
|
537 |
)
|
538 |
|
539 |
-
gr.Markdown("##
|
540 |
with gr.Row():
|
541 |
with gr.Column(scale=1):
|
542 |
-
gr.Markdown("###
|
543 |
-
|
544 |
-
headers=["Feature", "Cosine
|
545 |
interactive=False
|
546 |
)
|
547 |
with gr.Column(scale=1):
|
548 |
-
gr.Markdown("###
|
549 |
-
|
550 |
-
headers=["Feature", "Cosine
|
551 |
interactive=False
|
552 |
)
|
553 |
-
|
554 |
with gr.Row():
|
555 |
with gr.Column(scale=1):
|
556 |
-
gr.Markdown("## Top
|
557 |
co_occurring_features = gr.Dataframe(
|
558 |
headers=["Feature", "Co-occurrences"],
|
559 |
interactive=False
|
@@ -562,10 +881,31 @@ def create_interface():
|
|
562 |
gr.Markdown(f"## Activation Value Distribution")
|
563 |
activation_dist = gr.Plot()
|
564 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
def search_feature_labels(search_text, current_subject):
|
566 |
if not search_text:
|
567 |
return gr.CheckboxGroup(choices=[])
|
568 |
-
matches = [f
|
|
|
|
|
|
|
569 |
return gr.CheckboxGroup(choices=matches[:10])
|
570 |
|
571 |
feature_search.change(search_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])
|
@@ -576,15 +916,15 @@ def create_interface():
|
|
576 |
|
577 |
# Extract the feature index from the selected feature string
|
578 |
feature_index = int(selected_features[0].split('(')[-1].strip(')'))
|
579 |
-
feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist = visualize_feature(current_subject, feature_index)
|
580 |
|
581 |
# Return the visualization results along with empty values for search box and checkbox
|
582 |
-
return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", []
|
583 |
|
584 |
visualize_button.click(
|
585 |
on_visualize,
|
586 |
inputs=[feature_matches, subject],
|
587 |
-
outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches]
|
588 |
)
|
589 |
|
590 |
with gr.Tab("Feature Families"):
|
@@ -595,19 +935,26 @@ def create_interface():
|
|
595 |
family_matches = gr.CheckboxGroup(label="Matching Feature Families", choices=[])
|
596 |
visualize_family_button = gr.Button("Visualize Feature Family")
|
597 |
|
598 |
-
|
599 |
family_dataframe = gr.Dataframe(
|
600 |
-
headers=["Feature", "F1 Score", "Pearson
|
601 |
-
datatype=["markdown", "number", "number"],
|
602 |
label="Family and Child Features"
|
603 |
)
|
604 |
|
|
|
|
|
|
|
|
|
605 |
|
606 |
def search_feature_families(search_text, current_subject):
|
607 |
family_analysis = subject_data[current_subject]['family_analysis']
|
608 |
if not search_text:
|
609 |
return gr.CheckboxGroup(choices=[])
|
610 |
-
matches = [family
|
|
|
|
|
|
|
611 |
return gr.CheckboxGroup(choices=matches[:10]) # Limit to top 10 matches
|
612 |
|
613 |
def visualize_feature_family(selected_families, current_subject):
|
@@ -627,16 +974,20 @@ def create_interface():
|
|
627 |
df_data = [
|
628 |
{
|
629 |
"Feature": f"## {family_data['superfeature']}",
|
|
|
630 |
"F1 Score": round(family_data['family_f1'], 2),
|
631 |
-
"Pearson
|
632 |
},
|
633 |
]
|
634 |
|
635 |
-
|
|
|
|
|
636 |
df_data.append({
|
637 |
"Feature": name,
|
|
|
638 |
"F1 Score": round(f1, 2),
|
639 |
-
"Pearson
|
640 |
})
|
641 |
|
642 |
df = pd.DataFrame(df_data)
|
@@ -645,13 +996,17 @@ def create_interface():
|
|
645 |
output += "## Super Reasoning\n"
|
646 |
output += f"{family_data['super_reasoning']}\n\n"
|
647 |
|
648 |
-
|
|
|
|
|
|
|
649 |
|
650 |
family_search.change(search_feature_families, inputs=[family_search, subject], outputs=[family_matches])
|
651 |
visualize_family_button.click(
|
652 |
visualize_feature_family,
|
653 |
inputs=[family_matches, subject],
|
654 |
-
outputs=[family_info, family_dataframe, family_search, family_matches]
|
|
|
655 |
)
|
656 |
|
657 |
|
|
|
1 |
+
# import gradio as gr
|
2 |
+
# import numpy as np
|
3 |
+
# import json
|
4 |
+
# import pandas as pd
|
5 |
+
# from openai import OpenAI
|
6 |
+
# import yaml
|
7 |
+
# from typing import Optional, List, Dict, Tuple, Any
|
8 |
+
# from topk_sae import FastAutoencoder
|
9 |
+
# import torch
|
10 |
+
# import plotly.express as px
|
11 |
+
# from collections import Counter
|
12 |
+
# from huggingface_hub import hf_hub_download
|
13 |
+
# import os
|
14 |
+
# import networkx as nx
|
15 |
+
# import plotly.graph_objs as go
|
16 |
+
# from ast import literal_eval as make_tuple
|
17 |
+
# import random
|
18 |
+
|
19 |
+
# import os
|
20 |
+
# print(os.getenv('MODEL_REPO_ID'))
|
21 |
+
|
22 |
+
# # Constants
|
23 |
+
# EMBEDDING_MODEL = "text-embedding-3-small"
|
24 |
+
# d_model = 1536
|
25 |
+
# n_dirs = d_model * 6
|
26 |
+
# k = 64
|
27 |
+
# auxk = 128
|
28 |
+
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
29 |
+
# torch.set_grad_enabled(False)
|
30 |
+
|
31 |
+
# # Function to download all necessary files
|
32 |
+
# def download_all_files():
|
33 |
+
# files_to_download = [
|
34 |
+
# "astroPH_paper_metadata.csv",
|
35 |
+
# "csLG_feature_analysis_results_64.json",
|
36 |
+
# "astroPH_topk_indices_64_9216_int32.npy",
|
37 |
+
# "astroPH_64_9216.pth",
|
38 |
+
# "astroPH_topk_values_64_9216_float16.npy",
|
39 |
+
# "csLG_abstract_texts.json",
|
40 |
+
# "csLG_topk_values_64_9216_float16.npy",
|
41 |
+
# "csLG_abstract_embeddings_float16.npy",
|
42 |
+
# "csLG_paper_metadata.csv",
|
43 |
+
# "csLG_64_9216.pth",
|
44 |
+
# "astroPH_abstract_texts.json",
|
45 |
+
# "astroPH_feature_analysis_results_64.json",
|
46 |
+
# "csLG_topk_indices_64_9216_int32.npy",
|
47 |
+
# "astroPH_abstract_embeddings_float16.npy",
|
48 |
+
# # "csLG_clean_families_64_9216.json",
|
49 |
+
# # "astroPH_clean_families_64_9216.json",
|
50 |
+
# # "astroPH_family_analysis_64_9216.json",
|
51 |
+
# "csLG_family_analysis_64_9216.json"
|
52 |
+
# ]
|
53 |
+
|
54 |
+
# for file in files_to_download:
|
55 |
+
# local_path = os.path.join("data", file)
|
56 |
+
# os.makedirs(os.path.dirname(local_path), exist_ok=True)
|
57 |
+
# hf_hub_download(repo_id="charlieoneill/saerch-ai-data", filename=file, local_dir="data")
|
58 |
+
# print(f"Downloaded {file}")
|
59 |
+
|
60 |
+
# # Load configuration and initialize OpenAI client
|
61 |
+
# download_all_files()
|
62 |
+
|
63 |
+
# # Load the API key from the environment variable
|
64 |
+
# api_key = os.getenv('openai_key')
|
65 |
+
|
66 |
+
# # Ensure the API key is set
|
67 |
+
# if not api_key:
|
68 |
+
# raise ValueError("The environment variable 'openai_key' is not set.")
|
69 |
+
|
70 |
+
# # Initialize the OpenAI client with the API key
|
71 |
+
# client = OpenAI(api_key=api_key)
|
72 |
+
|
73 |
+
# # Function to load data for a specific subject
|
74 |
+
# def load_subject_data(subject):
|
75 |
+
|
76 |
+
# embeddings_path = f"data/{subject}_abstract_embeddings_float16.npy"
|
77 |
+
# texts_path = f"data/{subject}_abstract_texts.json"
|
78 |
+
# feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
|
79 |
+
# metadata_path = f'data/{subject}_paper_metadata.csv'
|
80 |
+
# topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
|
81 |
+
# norms_path = f"data/{subject}_norms_{k}_{n_dirs}.npy"
|
82 |
+
# topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
|
83 |
+
# families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
|
84 |
+
# family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
|
85 |
+
# nns_32to64 = json.load(open(f"data/{subject}_nns_32to64.json"))
|
86 |
+
# nns_16to32 = json.load(open(f"data/{subject}_nns_16to32.json"))
|
87 |
+
# nns_16to64 = json.load(open(f"data/{subject}_nns_16to64.json"))
|
88 |
+
|
89 |
+
# abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
|
90 |
+
# with open(texts_path, 'r') as f:
|
91 |
+
# abstract_texts = json.load(f)
|
92 |
+
# with open(feature_analysis_path, 'r') as f:
|
93 |
+
# feature_analysis = json.load(f)
|
94 |
+
# df_metadata = pd.read_csv(metadata_path)
|
95 |
+
# topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
|
96 |
+
# topk_values = np.load(topk_values_path).astype(np.float32)
|
97 |
+
# norms = np.load(norms_path).astype(np.float32)
|
98 |
+
|
99 |
+
# model_filename = f"{subject}_64_9216.pth"
|
100 |
+
# model_path = os.path.join("data", model_filename)
|
101 |
+
|
102 |
+
# ae = FastAutoencoder(n_dirs, d_model, k, auxk, multik=0).to(device)
|
103 |
+
# ae.load_state_dict(torch.load(model_path))
|
104 |
+
# ae.eval()
|
105 |
+
|
106 |
+
# weights = torch.load(model_path)
|
107 |
+
# decoder = weights['decoder.weight'].cpu().numpy()
|
108 |
+
# del weights
|
109 |
+
|
110 |
+
# with open(family_analysis_path, 'r') as f:
|
111 |
+
# family_analysis = json.load(f)
|
112 |
+
|
113 |
+
|
114 |
+
# return {
|
115 |
+
# 'abstract_embeddings': abstract_embeddings,
|
116 |
+
# 'abstract_texts': abstract_texts,
|
117 |
+
# 'feature_analysis': feature_analysis,
|
118 |
+
# 'df_metadata': df_metadata,
|
119 |
+
# 'topk_indices': topk_indices,
|
120 |
+
# 'topk_values': topk_values,
|
121 |
+
# 'norms': norms,
|
122 |
+
# 'nns_32to64': nns_32to64,
|
123 |
+
# 'nns_16to64': nns_16to64,
|
124 |
+
# 'ae': ae,
|
125 |
+
# 'decoder': decoder,
|
126 |
+
# # 'feature_families': feature_families,
|
127 |
+
# 'family_analysis': family_analysis
|
128 |
+
# }
|
129 |
+
|
130 |
import gradio as gr
|
131 |
import numpy as np
|
132 |
import json
|
|
|
140 |
from collections import Counter
|
141 |
from huggingface_hub import hf_hub_download
|
142 |
import os
|
143 |
+
import networkx as nx
|
144 |
+
import plotly.graph_objs as go
|
145 |
+
from ast import literal_eval as make_tuple
|
146 |
+
import random
|
147 |
|
148 |
import os
|
149 |
print(os.getenv('MODEL_REPO_ID'))
|
|
|
177 |
# "csLG_clean_families_64_9216.json",
|
178 |
# "astroPH_clean_families_64_9216.json",
|
179 |
"astroPH_family_analysis_64_9216.json",
|
180 |
+
"csLG_family_analysis_64_9216.json",
|
181 |
+
"csLG_nns_32to64.json",
|
182 |
+
"csLG_nns_16to32.json",
|
183 |
+
"csLG_nns_16to64.json",
|
184 |
+
"astroPH_nns_32to64.json",
|
185 |
+
"astroPH_nns_16to32.json",
|
186 |
+
"astroPH_nns_16to64.json",
|
187 |
+
"csLG_norms_64_9216_float16.npy",
|
188 |
+
"astroPH_norms_64_9216_float16.npy"
|
189 |
]
|
190 |
|
191 |
for file in files_to_download:
|
|
|
215 |
feature_analysis_path = f"data/{subject}_feature_analysis_results_{k}.json"
|
216 |
metadata_path = f'data/{subject}_paper_metadata.csv'
|
217 |
topk_indices_path = f"data/{subject}_topk_indices_{k}_{n_dirs}_int32.npy"
|
218 |
+
norms_path = f"data/{subject}_norms_{k}_{n_dirs}_float16.npy"
|
219 |
topk_values_path = f"data/{subject}_topk_values_{k}_{n_dirs}_float16.npy"
|
220 |
families_path = f"data/{subject}_clean_families_{k}_{n_dirs}.json"
|
221 |
family_analysis_path = f"data/{subject}_family_analysis_{k}_{n_dirs}.json"
|
222 |
+
nns_32to64 = json.load(open(f"data/{subject}_nns_32to64.json"))
|
223 |
+
nns_16to32 = json.load(open(f"data/{subject}_nns_16to32.json"))
|
224 |
+
nns_16to64 = json.load(open(f"data/{subject}_nns_16to64.json"))
|
225 |
|
226 |
abstract_embeddings = np.load(embeddings_path).astype(np.float32) # Load float16 and convert to float32
|
227 |
with open(texts_path, 'r') as f:
|
|
|
231 |
df_metadata = pd.read_csv(metadata_path)
|
232 |
topk_indices = np.load(topk_indices_path) # Already in int32, no conversion needed
|
233 |
topk_values = np.load(topk_values_path).astype(np.float32)
|
234 |
+
norms = np.load(norms_path).astype(np.float32)
|
235 |
|
236 |
model_filename = f"{subject}_64_9216.pth"
|
237 |
model_path = os.path.join("data", model_filename)
|
|
|
255 |
'df_metadata': df_metadata,
|
256 |
'topk_indices': topk_indices,
|
257 |
'topk_values': topk_values,
|
258 |
+
'norms': norms,
|
259 |
+
'nns_32to64': nns_32to64,
|
260 |
+
'nns_16to64': nns_16to64,
|
261 |
'ae': ae,
|
262 |
'decoder': decoder,
|
263 |
# 'feature_families': feature_families,
|
|
|
312 |
|
313 |
def calculate_co_occurrences(subject, target_index, n_features=9216):
|
314 |
topk_indices = subject_data[subject]['topk_indices']
|
315 |
+
norms = subject_data[subject]['norms']
|
316 |
|
317 |
mask = np.any(topk_indices == target_index, axis=1)
|
318 |
co_occurring_indices = topk_indices[mask].flatten()
|
319 |
co_occurrences = Counter(co_occurring_indices)
|
320 |
del co_occurrences[target_index]
|
321 |
+
result = np.zeros(n_features, dtype=np.float32)
|
322 |
result[list(co_occurrences.keys())] = list(co_occurrences.values())
|
323 |
+
result[list(co_occurrences.keys())] /= np.minimum(norms[list(co_occurrences.keys())], norms[target_index])
|
324 |
return result
|
325 |
|
326 |
def style_dataframe(df: pd.DataFrame, is_top: bool) -> pd.DataFrame:
|
|
|
442 |
"Co-occurrences": topk_values_co_occurrence
|
443 |
})
|
444 |
df_co_occurrences_styled = df_co_occurrences.style.format({
|
445 |
+
"Co-occurrences": "{:.2f}" # 2 decimal points
|
446 |
})
|
447 |
|
448 |
+
# Add new code for feature splitting
|
449 |
+
nns_16to64 = subject_data[subject]['nns_16to64']
|
450 |
+
nns_32to64 = subject_data[subject]['nns_32to64']
|
451 |
+
|
452 |
+
# Get nearest neighbors for 16 and 32
|
453 |
+
#nn_16 = nns_16to64[str(index)]
|
454 |
+
|
455 |
+
# this is really involved it's a lot easier the other direction
|
456 |
+
nn_16 = []
|
457 |
+
for key in nns_16to64.keys():
|
458 |
+
for match in nns_16to64[key]:
|
459 |
+
if index == match['feature'][0]:
|
460 |
+
nn_16.append([key, float(match['similarity'])])
|
461 |
+
|
462 |
+
#nn_32 = nns_32to64[str(index)]
|
463 |
+
nn_32 = []
|
464 |
+
for key in nns_32to64.keys():
|
465 |
+
for match in nns_32to64[key]:
|
466 |
+
if index == match['feature'][0]:
|
467 |
+
nn_32.append([key, float(match['similarity'])])
|
468 |
+
|
469 |
+
# Create dataframes for 16 and 32 nearest neighbors
|
470 |
+
try:
|
471 |
+
df_16 = pd.DataFrame(nn_16, columns=["Feature", "Cosine Similarity"])
|
472 |
+
df_16 = df_16.style.format({"Cosine Similarity": "{:.4f}"})
|
473 |
+
except:
|
474 |
+
df_16 = pd.DataFrame(["No Match"], columns=["Feature"])
|
475 |
+
|
476 |
+
try:
|
477 |
+
df_32 = pd.DataFrame(nn_32, columns=["Feature", "Cosine Similarity"])
|
478 |
+
df_32 = df_32.style.format({"Cosine Similarity": "{:.4f}"})
|
479 |
+
except:
|
480 |
+
df_32 = pd.DataFrame(["No Match"], columns=["Feature"])
|
481 |
+
|
482 |
+
return output, styled_top_abstracts, df_top_correlated_styled, df_bottom_correlated_styled, df_co_occurrences_styled, fig2, df_16, df_32
|
483 |
+
|
484 |
+
def create_interactive_directed_graph(family):
|
485 |
+
matrix = np.array(family['matrix'])
|
486 |
+
matrix[matrix < 0.07] = 0
|
487 |
+
densities = family['densities']
|
488 |
+
for i in range(len(densities)):
|
489 |
+
for j in range(len(densities)):
|
490 |
+
if densities[i] < densities[j]:
|
491 |
+
matrix[i][j] = 0
|
492 |
+
|
493 |
+
G = nx.from_numpy_array(matrix, create_using=nx.DiGraph())
|
494 |
+
|
495 |
+
num_nodes = len(family['feature_f1'])
|
496 |
+
all_f1s = family['feature_pearson'] + [family['family_pearson']]
|
497 |
+
node_info = {i: {"name": f"{family['feature_names'][i]}", "density": family['densities'][i], "pearson": all_f1s[i]} for i in range(num_nodes)}
|
498 |
+
nx.set_node_attributes(G, node_info)
|
499 |
+
|
500 |
+
# Create node trace
|
501 |
+
node_x = []
|
502 |
+
node_y = []
|
503 |
+
node_text = []
|
504 |
+
node_size = []
|
505 |
+
node_color = []
|
506 |
+
pos = nx.spring_layout(G, k = np.sqrt(1/num_nodes) * 3)
|
507 |
+
for node in G.nodes():
|
508 |
+
x, y = pos[node]
|
509 |
+
node_x.append(x)
|
510 |
+
node_y.append(y)
|
511 |
+
node_text.append(G.nodes[node]['name'] + "<br>log density: " + str(round(np.log10(G.nodes[node]['density'] + 1e-5), 3)))
|
512 |
+
node_size.append((np.log10(G.nodes[node]['density'] + 1e-5) + 6) * 10)
|
513 |
+
node_color.append(G.nodes[node]['pearson'])
|
514 |
+
|
515 |
+
node_trace = go.Scatter(
|
516 |
+
x=node_x, y=node_y,
|
517 |
+
mode='markers',
|
518 |
+
hoverinfo='text',
|
519 |
+
marker=dict(
|
520 |
+
showscale=True,
|
521 |
+
colorscale='purples',
|
522 |
+
size=node_size, # Set node marker size to node['f1']
|
523 |
+
color=node_color,
|
524 |
+
cmin = 0,
|
525 |
+
cmax = 1,
|
526 |
+
colorbar=dict(
|
527 |
+
thickness=15,
|
528 |
+
title='Pearson Correlation',
|
529 |
+
xanchor='left',
|
530 |
+
titleside='right',
|
531 |
+
),
|
532 |
+
line_width=2,
|
533 |
+
opacity = 1,),
|
534 |
+
opacity = 1)
|
535 |
+
|
536 |
+
node_trace.text = node_text
|
537 |
+
|
538 |
+
# Create edge trace
|
539 |
+
edge_traces = []
|
540 |
+
annotations = []
|
541 |
+
for edge in G.edges():
|
542 |
+
x0, y0 = pos[edge[0]]
|
543 |
+
x1, y1 = pos[edge[1]]
|
544 |
+
weight = matrix[edge[0], edge[1]]
|
545 |
+
|
546 |
+
# Calculate offset (adjust this value to move arrows further from or closer to nodes)
|
547 |
+
offset = 0.00
|
548 |
+
start_x = x0
|
549 |
+
start_y = y0
|
550 |
+
end_x = x1
|
551 |
+
end_y = y1
|
552 |
+
|
553 |
+
# # Calculate new start and end points
|
554 |
+
# if start_x > end_x:
|
555 |
+
# start_x = x0 - offset
|
556 |
+
# end_x = x0 + offset
|
557 |
+
# else:
|
558 |
+
# start_x = x0 + offset
|
559 |
+
# end_x = x1 - offset
|
560 |
+
# if start_y > end_y:
|
561 |
+
# start_y = y0 - offset
|
562 |
+
# end_y = y1 + offset
|
563 |
+
# else:
|
564 |
+
# start_y = y0 + offset
|
565 |
+
# end_y = y1 - offset
|
566 |
+
|
567 |
+
edge_trace = go.Scatter(
|
568 |
+
x=[start_x, end_x, None],
|
569 |
+
y=[start_y, end_y, None],
|
570 |
+
line=dict(width=weight * 20, color='#888'), # Multiply weight by 20 for better visibility
|
571 |
+
hovertext="weight: " + str(round(weight, 3)), # Set the hover text to the edge weight
|
572 |
+
mode='lines',
|
573 |
+
line_shape='spline',
|
574 |
+
opacity = 0.5,
|
575 |
+
)
|
576 |
+
edge_traces.append(edge_trace)
|
577 |
+
|
578 |
+
annotation = dict(
|
579 |
+
ax=start_x,
|
580 |
+
ay=start_y,
|
581 |
+
x=end_x,
|
582 |
+
y=end_y,
|
583 |
+
xref='x',
|
584 |
+
yref='y',
|
585 |
+
axref='x',
|
586 |
+
ayref='y',
|
587 |
+
showarrow=True,
|
588 |
+
arrowhead=4,
|
589 |
+
arrowsize=4, #max(min(weight * 3, 0.3), 2), # Reduced from 30 to 10
|
590 |
+
arrowwidth=1, # Reduced from 30 to 2
|
591 |
+
arrowcolor='#999',
|
592 |
+
opacity = 1,
|
593 |
+
)
|
594 |
+
annotations.append(annotation)
|
595 |
+
|
596 |
+
annotation_trace = go.Scatter(x=[], y=[], mode='markers', hoverinfo='none', marker=dict(opacity=0))
|
597 |
+
|
598 |
+
# Create the figure
|
599 |
+
fig = go.Figure(data=[annotation_trace, *edge_traces, node_trace],
|
600 |
+
layout=go.Layout(
|
601 |
+
showlegend=False,
|
602 |
+
hovermode='closest',
|
603 |
+
margin=dict(b=20,l=5,r=5,t=40),
|
604 |
+
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
|
605 |
+
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)),
|
606 |
+
)
|
607 |
+
fig.update_xaxes(showline=False, linewidth=0, gridcolor='white')
|
608 |
+
fig.update_yaxes(showline=False, linewidth=0, gridcolor='white')
|
609 |
+
fig.update_layout(
|
610 |
+
plot_bgcolor='white',
|
611 |
+
annotations=annotations,
|
612 |
+
)
|
613 |
+
return fig
|
614 |
|
615 |
# Modify the main interface function
|
616 |
def create_interface():
|
|
|
769 |
def search_feature_labels(search_text):
|
770 |
if not search_text:
|
771 |
return gr.CheckboxGroup(choices=[])
|
772 |
+
matches = [f for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
|
773 |
+
matches = sorted(matches, key=lambda x: x['pearson_correlation'], reverse=True)
|
774 |
+
matches = [f"{f['label']} ({f['index']})" for f in matches]
|
775 |
+
|
776 |
return gr.CheckboxGroup(choices=matches[:10])
|
777 |
|
778 |
feature_search.change(search_feature_labels, inputs=[feature_search], outputs=[feature_matches])
|
|
|
855 |
wrap=True
|
856 |
)
|
857 |
|
858 |
+
gr.Markdown("## Feature Splitting")
|
859 |
with gr.Row():
|
860 |
with gr.Column(scale=1):
|
861 |
+
gr.Markdown("### Best Match in SAE16")
|
862 |
+
nn_16_table = gr.Dataframe(
|
863 |
+
headers=["Feature", "Cosine Similarity"],
|
864 |
interactive=False
|
865 |
)
|
866 |
with gr.Column(scale=1):
|
867 |
+
gr.Markdown("### Best Match in SAE32")
|
868 |
+
nn_32_table = gr.Dataframe(
|
869 |
+
headers=["Feature", "Cosine Similarity"],
|
870 |
interactive=False
|
871 |
)
|
872 |
+
|
873 |
with gr.Row():
|
874 |
with gr.Column(scale=1):
|
875 |
+
gr.Markdown("## Top Co-occurring Features")
|
876 |
co_occurring_features = gr.Dataframe(
|
877 |
headers=["Feature", "Co-occurrences"],
|
878 |
interactive=False
|
|
|
881 |
gr.Markdown(f"## Activation Value Distribution")
|
882 |
activation_dist = gr.Plot()
|
883 |
|
884 |
+
gr.Markdown("## Correlated Features")
|
885 |
+
with gr.Row():
|
886 |
+
with gr.Column(scale=1):
|
887 |
+
gr.Markdown("### Top Correlated Features")
|
888 |
+
top_correlated = gr.Dataframe(
|
889 |
+
headers=["Feature", "Cosine similarity"],
|
890 |
+
interactive=False
|
891 |
+
)
|
892 |
+
with gr.Column(scale=1):
|
893 |
+
gr.Markdown("### Bottom Correlated Features")
|
894 |
+
bottom_correlated = gr.Dataframe(
|
895 |
+
headers=["Feature", "Cosine similarity"],
|
896 |
+
interactive=False
|
897 |
+
)
|
898 |
+
|
899 |
+
|
900 |
+
|
901 |
+
|
902 |
def search_feature_labels(search_text, current_subject):
|
903 |
if not search_text:
|
904 |
return gr.CheckboxGroup(choices=[])
|
905 |
+
matches = [f for f in subject_data[current_subject]['feature_analysis'] if search_text.lower() in f['label'].lower()]
|
906 |
+
matches = sorted(matches, key=lambda x: x['pearson_correlation'], reverse=True)
|
907 |
+
matches = [f"{f['label']} ({f['index']})" for f in matches]
|
908 |
+
|
909 |
return gr.CheckboxGroup(choices=matches[:10])
|
910 |
|
911 |
feature_search.change(search_feature_labels, inputs=[feature_search, subject], outputs=[feature_matches])
|
|
|
916 |
|
917 |
# Extract the feature index from the selected feature string
|
918 |
feature_index = int(selected_features[0].split('(')[-1].strip(')'))
|
919 |
+
feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, nn_16, nn_32 = visualize_feature(current_subject, feature_index)
|
920 |
|
921 |
# Return the visualization results along with empty values for search box and checkbox
|
922 |
+
return feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, "", [], nn_16, nn_32
|
923 |
|
924 |
visualize_button.click(
|
925 |
on_visualize,
|
926 |
inputs=[feature_matches, subject],
|
927 |
+
outputs=[feature_info, top_abstracts, top_correlated, bottom_correlated, co_occurring_features, activation_dist, feature_search, feature_matches, nn_16_table, nn_32_table]
|
928 |
)
|
929 |
|
930 |
with gr.Tab("Feature Families"):
|
|
|
935 |
family_matches = gr.CheckboxGroup(label="Matching Feature Families", choices=[])
|
936 |
visualize_family_button = gr.Button("Visualize Feature Family")
|
937 |
|
938 |
+
|
939 |
family_dataframe = gr.Dataframe(
|
940 |
+
headers=["Feature", "Parent Co-Occurrence", "F1 Score", "Pearson"],
|
941 |
+
datatype=["markdown", "number", "number", "number"],
|
942 |
label="Family and Child Features"
|
943 |
)
|
944 |
|
945 |
+
gr.Markdown("# Family Graph")
|
946 |
+
graph_plot = gr.Plot(label="Directed Graph")
|
947 |
+
|
948 |
+
# family_info = gr.Markdown()
|
949 |
|
950 |
def search_feature_families(search_text, current_subject):
|
951 |
family_analysis = subject_data[current_subject]['family_analysis']
|
952 |
if not search_text:
|
953 |
return gr.CheckboxGroup(choices=[])
|
954 |
+
matches = [family for family in family_analysis if search_text.lower() in family['superfeature'].lower()]
|
955 |
+
matches = sorted(matches, key=lambda x: x['family_pearson'], reverse=True)
|
956 |
+
matches = [family['superfeature'] for family in matches]
|
957 |
+
matches = list(dict.fromkeys(matches))
|
958 |
return gr.CheckboxGroup(choices=matches[:10]) # Limit to top 10 matches
|
959 |
|
960 |
def visualize_feature_family(selected_families, current_subject):
|
|
|
974 |
df_data = [
|
975 |
{
|
976 |
"Feature": f"## {family_data['superfeature']}",
|
977 |
+
"Parent Co-Occurrence": 1,
|
978 |
"F1 Score": round(family_data['family_f1'], 2),
|
979 |
+
"Pearson": round(family_data['family_pearson'], 4)
|
980 |
},
|
981 |
]
|
982 |
|
983 |
+
coocs = np.array(family_data['matrix'])[:, -1]
|
984 |
+
# print(coocs)
|
985 |
+
for name, cooc, f1, pearson in zip(family_data['feature_names'], coocs, family_data['feature_f1'], family_data['feature_pearson']):
|
986 |
df_data.append({
|
987 |
"Feature": name,
|
988 |
+
"Parent Co-Occurrence": round(cooc, 2),
|
989 |
"F1 Score": round(f1, 2),
|
990 |
+
"Pearson": round(pearson, 4)
|
991 |
})
|
992 |
|
993 |
df = pd.DataFrame(df_data)
|
|
|
996 |
output += "## Super Reasoning\n"
|
997 |
output += f"{family_data['super_reasoning']}\n\n"
|
998 |
|
999 |
+
graph = create_interactive_directed_graph(family_data)
|
1000 |
+
|
1001 |
+
#return output, df, "", [], graph # Return empty string for search box and empty list for checkbox
|
1002 |
+
return df, "", [], graph
|
1003 |
|
1004 |
family_search.change(search_feature_families, inputs=[family_search, subject], outputs=[family_matches])
|
1005 |
visualize_family_button.click(
|
1006 |
visualize_feature_family,
|
1007 |
inputs=[family_matches, subject],
|
1008 |
+
#outputs=[family_info, family_dataframe, family_search, family_matches, graph_plot]
|
1009 |
+
outputs=[family_dataframe, family_search, family_matches, graph_plot]
|
1010 |
)
|
1011 |
|
1012 |
|