Corey Morris
committed on
Commit
•
c671de9
1
Parent(s):
ed019c6
Added an overall MMLU average column. Added several charts comparing the moral-reasoning benchmarks and comparing the overall MMLU average against other benchmark data.
Browse files
app.py
CHANGED
@@ -25,15 +25,16 @@ class MultiURLData:
|
|
25 |
data = json.load(f)
|
26 |
df = pd.DataFrame(data['results']).T
|
27 |
|
28 |
-
df = df.rename(columns={'acc': model_name})
|
29 |
-
|
30 |
-
df.index = df.index.str.replace('hendrycksTest-', '', regex=True)
|
31 |
|
|
|
|
|
|
|
|
|
32 |
df.index = df.index.str.replace('harness\|', '', regex=True)
|
33 |
-
|
34 |
# remove |5 from the index
|
35 |
df.index = df.index.str.replace('\|5', '', regex=True)
|
36 |
|
|
|
37 |
dataframes.append(df[[model_name]])
|
38 |
|
39 |
data = pd.concat(dataframes, axis=1)
|
@@ -44,7 +45,18 @@ class MultiURLData:
|
|
44 |
cols = cols[-1:] + cols[:-1]
|
45 |
data = data[cols]
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
return data
|
|
|
|
|
48 |
|
49 |
def get_data(self, selected_models):
|
50 |
filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
|
@@ -75,6 +87,7 @@ selected_models = st.multiselect(
|
|
75 |
|
76 |
|
77 |
# Get the filtered data and display it in a table
|
|
|
78 |
filtered_data = data_provider.get_data(selected_models)
|
79 |
st.dataframe(filtered_data)
|
80 |
|
@@ -111,11 +124,34 @@ def create_plot(df, model_column, arc_column, moral_column, models=None):
|
|
111 |
# models_to_plot = ['Model1', 'Model2', 'Model3']
|
112 |
# fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
|
113 |
|
114 |
-
|
115 |
-
st.plotly_chart(fig)
|
116 |
|
117 |
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
|
118 |
st.plotly_chart(fig)
|
119 |
|
120 |
-
fig = create_plot(filtered_data, 'Model Name', '
|
121 |
st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
data = json.load(f)
|
26 |
df = pd.DataFrame(data['results']).T
|
27 |
|
|
|
|
|
|
|
28 |
|
29 |
+
# data cleanup
|
30 |
+
df = df.rename(columns={'acc': model_name})
|
31 |
+
# Replace the 'hendrycksTest-' prefix in the index with the more descriptive 'MMLU_' prefix
|
32 |
+
df.index = df.index.str.replace('hendrycksTest-', 'MMLU_', regex=True)
|
33 |
df.index = df.index.str.replace('harness\|', '', regex=True)
|
|
|
34 |
# remove |5 from the index
|
35 |
df.index = df.index.str.replace('\|5', '', regex=True)
|
36 |
|
37 |
+
|
38 |
dataframes.append(df[[model_name]])
|
39 |
|
40 |
data = pd.concat(dataframes, axis=1)
|
|
|
45 |
cols = cols[-1:] + cols[:-1]
|
46 |
data = data[cols]
|
47 |
|
48 |
+
# create a new column that averages the results from each of the columns whose name contains MMLU
|
49 |
+
data['MMLU_average'] = data.filter(regex='MMLU').mean(axis=1)
|
50 |
+
|
51 |
+
# move the MMLU_average column to the second column in the dataframe
|
52 |
+
cols = data.columns.tolist()
|
53 |
+
cols = cols[:1] + cols[-1:] + cols[1:-1]
|
54 |
+
data = data[cols]
|
55 |
+
data
|
56 |
+
|
57 |
return data
|
58 |
+
|
59 |
+
|
60 |
|
61 |
def get_data(self, selected_models):
|
62 |
filtered_data = self.data[self.data['Model Name'].isin(selected_models)]
|
|
|
87 |
|
88 |
|
89 |
# Get the filtered data and display it in a table
|
90 |
+
st.header('Sortable table')
|
91 |
filtered_data = data_provider.get_data(selected_models)
|
92 |
st.dataframe(filtered_data)
|
93 |
|
|
|
124 |
# models_to_plot = ['Model1', 'Model2', 'Model3']
|
125 |
# fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'moral_scenarios|5', models=models_to_plot)
|
126 |
|
127 |
+
st.header('Overall benchmark comparison')
|
|
|
128 |
|
129 |
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'hellaswag|10')
|
130 |
st.plotly_chart(fig)
|
131 |
|
132 |
+
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_average')
|
133 |
st.plotly_chart(fig)
|
134 |
+
|
135 |
+
fig = create_plot(filtered_data, 'Model Name', 'hellaswag|10', 'MMLU_average')
|
136 |
+
st.plotly_chart(fig)
|
137 |
+
|
138 |
+
# Add heading to page to say Moral Scenarios
|
139 |
+
st.header('Moral Scenarios')
|
140 |
+
|
141 |
+
fig = create_plot(filtered_data, 'Model Name', 'arc:challenge|25', 'MMLU_moral_scenarios')
|
142 |
+
st.plotly_chart(fig)
|
143 |
+
|
144 |
+
|
145 |
+
fig = create_plot(filtered_data, 'Model Name', 'MMLU_moral_disputes', 'MMLU_moral_scenarios')
|
146 |
+
st.plotly_chart(fig)
|
147 |
+
|
148 |
+
fig = create_plot(filtered_data, 'Model Name', 'MMLU_average', 'MMLU_moral_scenarios')
|
149 |
+
st.plotly_chart(fig)
|
150 |
+
|
151 |
+
# create a histogram of moral scenarios
|
152 |
+
fig = px.histogram(filtered_data, x="MMLU_moral_scenarios", marginal="rug", hover_data=filtered_data.columns)
|
153 |
+
st.plotly_chart(fig)
|
154 |
+
|
155 |
+
# create a histogram of moral disputes
|
156 |
+
fig = px.histogram(filtered_data, x="MMLU_moral_disputes", marginal="rug", hover_data=filtered_data.columns)
|
157 |
+
st.plotly_chart(fig)
|