zixianma committed
Commit 95d0904
1 Parent(s): baea0cd

added application file along with data

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +1231 -0
  2. db/2d/2d-how-many/embeddings.pkl +3 -0
  3. db/2d/2d-how-many/expanded_data.csv +0 -0
  4. db/2d/2d-how-many/gt.pkl +3 -0
  5. db/2d/2d-how-many/instructblip_vicuna13b_surprise.pkl +3 -0
  6. db/2d/2d-how-many/instructblip_vicuna7b_surprise.pkl +3 -0
  7. db/2d/2d-how-many/llava15_13b_surprise.pkl +3 -0
  8. db/2d/2d-how-many/llava15_7b_surprise.pkl +3 -0
  9. db/2d/2d-how-many/merged_data.csv +0 -0
  10. db/2d/2d-how-many/path.json +0 -0
  11. db/2d/2d-how-many/qa.pkl +3 -0
  12. db/2d/2d-how-many/qwenvl_chat_surprise.pkl +3 -0
  13. db/2d/2d-how-many/qwenvl_surprise.pkl +3 -0
  14. db/2d/2d-how-many/task_plan.pkl +3 -0
  15. db/2d/2d-what-attribute/embeddings.pkl +3 -0
  16. db/2d/2d-what-attribute/expanded_data.csv +0 -0
  17. db/2d/2d-what-attribute/gt.pkl +3 -0
  18. db/2d/2d-what-attribute/instructblip_vicuna13b_surprise.pkl +3 -0
  19. db/2d/2d-what-attribute/instructblip_vicuna7b_surprise.pkl +3 -0
  20. db/2d/2d-what-attribute/llava15_13b_surprise.pkl +3 -0
  21. db/2d/2d-what-attribute/llava15_7b_surprise.pkl +3 -0
  22. db/2d/2d-what-attribute/merged_data.csv +0 -0
  23. db/2d/2d-what-attribute/path.json +0 -0
  24. db/2d/2d-what-attribute/qa.pkl +3 -0
  25. db/2d/2d-what-attribute/qwenvl_chat_surprise.pkl +3 -0
  26. db/2d/2d-what-attribute/qwenvl_surprise.pkl +3 -0
  27. db/2d/2d-what-attribute/task_plan.pkl +3 -0
  28. db/2d/2d-what/embeddings.pkl +3 -0
  29. db/2d/2d-what/expanded_data.csv +0 -0
  30. db/2d/2d-what/gt.pkl +3 -0
  31. db/2d/2d-what/instructblip_vicuna13b_surprise.pkl +3 -0
  32. db/2d/2d-what/instructblip_vicuna7b_surprise.pkl +3 -0
  33. db/2d/2d-what/llava15_13b_surprise.pkl +3 -0
  34. db/2d/2d-what/llava15_7b_surprise.pkl +3 -0
  35. db/2d/2d-what/merged_data.csv +0 -0
  36. db/2d/2d-what/path.json +0 -0
  37. db/2d/2d-what/qa.pkl +3 -0
  38. db/2d/2d-what/qwenvl_chat_surprise.pkl +3 -0
  39. db/2d/2d-what/qwenvl_surprise.pkl +3 -0
  40. db/2d/2d-what/task_plan.pkl +3 -0
  41. db/2d/2d-where-attribute/embeddings.pkl +3 -0
  42. db/2d/2d-where-attribute/expanded_data.csv +0 -0
  43. db/2d/2d-where-attribute/gt.pkl +3 -0
  44. db/2d/2d-where-attribute/instructblip_vicuna13b_surprise.pkl +3 -0
  45. db/2d/2d-where-attribute/instructblip_vicuna7b_surprise.pkl +3 -0
  46. db/2d/2d-where-attribute/llava15_13b_surprise.pkl +3 -0
  47. db/2d/2d-where-attribute/llava15_7b_surprise.pkl +3 -0
  48. db/2d/2d-where-attribute/merged_data.csv +0 -0
  49. db/2d/2d-where-attribute/path.json +0 -0
  50. db/2d/2d-where-attribute/qa.pkl +3 -0
app.py ADDED
@@ -0,0 +1,1231 @@
import pandas as pd
import numpy as np
import os
import pickle  # used for the *.pkl loads below; previously only available via the wildcard import
from copy import deepcopy

import gradio as gr
import altair as alt
alt.data_transformers.enable("vegafusion")
from dynabench.task_evaluator import *  # provides find_frequent_patterns, among others

BASE_DIR = "../db"
MODELS = ['qwenvl-chat', 'qwenvl', 'llava15-7b', 'llava15-13b', 'instructblip-vicuna13b', 'instructblip-vicuna7b']
VIDEO_MODELS = ['video-chat2-7b', 'video-llama2-7b', 'video-llama2-13b', 'chat-univi-7b', 'chat-univi-13b', 'video-llava-7b', 'video-chatgpt-7b']
domains = ["imageqa-2d-sticker", "imageqa-3d-tabletop", "imageqa-scene-graph", "videoqa-3d-tabletop", "videoqa-scene-graph"]
domain2folder = {"imageqa-2d-sticker": "2d",
                 "imageqa-3d-tabletop": "3d",
                 "imageqa-scene-graph": "sg",
                 "videoqa-3d-tabletop": "video-3d",
                 "videoqa-scene-graph": "video-sg",
                 None: '2d'}
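
# Illustrative note (added for this writeup, not part of the original commit): judging
# from the data files added alongside app.py, each partition folder under BASE_DIR is
# expected to look roughly like
#   db/<domain>/<partition>/merged_data.csv       # one row per task, with per-model score columns
#   db/<domain>/<partition>/expanded_data.csv     # long format: (task id, model, score, metadata...)
#   db/<domain>/<partition>/task_plan.pkl         # pickled task-metadata table
#   db/<domain>/<partition>/embeddings.pkl        # task embeddings
#   db/<domain>/<partition>/<model>_surprise.pkl  # cached calc_surprisingness output
# The column descriptions are assumptions inferred from how the files are read below.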

def update_partition_and_models(domain):
    domain = domain2folder[domain]
    path = f"{BASE_DIR}/{domain}"

    if os.path.exists(path):
        partitions = list_directories(path)
        if domain.find("video") > -1:
            model = gr.Dropdown(VIDEO_MODELS, value=VIDEO_MODELS[0], label="model")
        else:
            model = gr.Dropdown(MODELS, value=MODELS[0], label="model")

        partition = gr.Dropdown(partitions, value=partitions[0], label="task space of the following task generator")
        return [partition, model]
    else:
        partition = gr.Dropdown([], value=None, label="task space of the following task generator")
        model = gr.Dropdown([], value=None, label="model")
        return [partition, model]

def update_partition_and_models_and_baselines(domain):
    domain = domain2folder[domain]
    path = f"{BASE_DIR}/{domain}"

    if os.path.exists(path):
        partitions = list_directories(path)
        if domain.find("video") > -1:
            model = gr.Dropdown(VIDEO_MODELS, value=VIDEO_MODELS[0], label="model")
            baseline = gr.Dropdown(VIDEO_MODELS, value=VIDEO_MODELS[0], label="baseline")
        else:
            model = gr.Dropdown(MODELS, value=MODELS[0], label="model")
            baseline = gr.Dropdown(MODELS, value=MODELS[0], label="baseline")

        partition = gr.Dropdown(partitions, value=partitions[0], label="task space of the following task generator")
    else:
        partition = gr.Dropdown([], value=None, label="task space of the following task generator")
        model = gr.Dropdown([], value=None, label="model")
        baseline = gr.Dropdown([], value=None, label="baseline")
    return [partition, model, baseline]

def get_filtered_task_ids(domain, partition, models, rank, k, threshold, baseline):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
    if not os.path.exists(data_path):
        return []
    else:
        merged_df = pd.read_csv(data_path)
        merged_df.rename(columns={'llavav1.5-7b': 'llava15-7b', 'llavav1.5-13b': 'llava15-13b'}, inplace=True)

        df = merged_df

        select_top = rank == "top"
        # Model X is good / bad at
        for model in models:
            if baseline:
                df = df[df[model] >= df[baseline]]
            else:
                if select_top:
                    df = df[df[model] >= threshold]
                else:
                    df = df[df[model] <= threshold]
        if not baseline:
            df['mean score'] = df[models].mean(axis=1)
            df = df.sort_values(by='mean score', ascending=False)
            df = df.iloc[:k, :] if select_top else df.iloc[-k:, :]

        task_ids = list(df.index)
        return task_ids
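
def _filtering_demo():
    # Illustrative sketch (added for this writeup, not part of the original commit):
    # the core of get_filtered_task_ids on an in-memory frame with made-up scores.
    # With rank="top" and threshold=0.5, every listed model must clear the bar, and
    # the surviving tasks are ranked by their mean score before taking the top k.
    df = pd.DataFrame({'qwenvl': [0.9, 0.2, 0.7], 'llava15-7b': [0.8, 0.9, 0.6]})
    models, threshold, k = ['qwenvl', 'llava15-7b'], 0.5, 1
    for model in models:
        df = df[df[model] >= threshold]      # keep tasks where this model scores >= 0.5
    df['mean score'] = df[models].mean(axis=1)
    df = df.sort_values(by='mean score', ascending=False)
    return list(df.iloc[:k, :].index)        # -> [0]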

def plot_patterns(domain, partition, models, rank, k, threshold, baseline, pattern, order):
    # get_filtered_task_ids maps the scenario name through domain2folder itself; the
    # original passed the already-mapped folder name, which raises a KeyError on the
    # second lookup, so the raw name is kept around here.
    raw_domain = domain
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        return None
    task_ids = get_filtered_task_ids(raw_domain, partition, models, rank, k, threshold, baseline)
    expand_df = pd.read_csv(data_path)

    chart_df = expand_df[expand_df['model'].isin((models + [baseline]) if baseline else models)]
    chart_df = chart_df[chart_df['task id'].isin(task_ids)]
    print(pattern)
    freq, cols = eval(pattern)  # pattern is a "(count, [(column, value), ...])" string from the dropdown
    pattern_str = ""
    df = chart_df
    for col in cols:
        col_name, col_val = col
        try:
            col_val = int(col_val)
        except (ValueError, TypeError):
            pass  # keep non-numeric values as-is
        df = df[df[col_name] == col_val]
        pattern_str += f"{col_name} = {col_val}, "
    print(len(df))

    if baseline:
        model_str = (', '.join(models) if len(models) > 1 else models[0])
        phrase = f'{model_str} perform' if len(models) > 1 else f'{model_str} performs'
        title = f"{phrase} better than {baseline} on {freq} tasks where {pattern_str[:-2]}"
    else:
        title = f"Models are {'best' if rank == 'top' else 'worst'} at {freq} tasks where {pattern_str[:-2]}"

    chart = alt.Chart(df).mark_bar().encode(
        alt.X('model:N',
              sort=alt.EncodingSortField(field='score', order=order, op="mean"),
              axis=alt.Axis(labels=False, tickSize=0)),  # no labels, no ticks
        alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
        alt.Color('model:N').legend(),
    ).properties(
        width=400,
        height=300,
        title=title
    )
    return chart

def plot_embedding(domain, partition, category):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"

    if os.path.exists(data_path):
        merged_df = pd.read_csv(data_path)
        # models = merged_df.columns
        has_image = 'image' in merged_df
        chart = alt.Chart(merged_df).mark_point(size=30, filled=True).encode(
            alt.OpacityValue(0.5),
            alt.X('x:Q'),
            alt.Y('y:Q'),
            alt.Color(f'{category}:N'),
            tooltip=['question', 'answer'] + (['image'] if has_image else []),
        ).properties(
            width=800,
            height=800,
            title="UMAP Projected Task Embeddings"
        ).configure_axis(
            labelFontSize=25,
            titleFontSize=25,
        ).configure_title(
            fontSize=40
        ).configure_legend(
            labelFontSize=25,
            titleFontSize=25,
        ).interactive()
        return chart
    else:
        return None


def plot_multi_models(domain, partition, category, cat_options, models, order, pattern, aggregate="mean"):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        return None
    expand_df = pd.read_csv(data_path)
    print(pattern)
    if pattern is not None:
        df = expand_df
        freq, cols = eval(pattern)
        pattern_str = ""
        for col in cols:
            col_name, col_val = col
            try:
                col_val = int(col_val)
            except (ValueError, TypeError):
                pass  # keep non-numeric values as-is
            df = df[df[col_name] == col_val]
            pattern_str += f"{col_name} = {col_val}, "
        chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                  sort=alt.EncodingSortField(field='score', order='ascending', op="mean"),
                  axis=alt.Axis(labels=False, tickSize=0)),  # no labels, no ticks
            alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
            alt.Color('model:N').legend(),
        ).properties(
            width=200,
            height=100,
            title=f"How do models perform on tasks where {pattern_str[:-2]} (N={freq})?"
        )
        return chart
    else:
        df = expand_df[(expand_df['model'].isin(models)) & (expand_df[category].isin(cat_options))]
        if len(models) > 1:
            chart = alt.Chart(df).mark_bar().encode(
                alt.X('model:N',
                      sort=alt.EncodingSortField(field='score', order=order, op="mean"),
                      axis=alt.Axis(labels=False, tickSize=0, title=None)),
                alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
                alt.Color('model:N').legend(),
                alt.Column(f'{category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom'))
            ).properties(
                width=200,
                height=100,
                title=f"How do models perform across {category}?"
            )
        else:
            chart = alt.Chart(df).mark_bar().encode(
                alt.X(f'{category}:N', sort=alt.EncodingSortField(field='score', order=order, op="mean")),
                alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
                alt.Color(f'{category}:N').legend(None),
            ).properties(
                width=200,
                height=100,
                title=f"How does {models[0]} perform across {category}?"
            )
        chart = chart.configure_title(fontSize=15, offset=5, orient='top', anchor='middle')
        return chart


def plot(domain, partition, models, rank, k, threshold, baseline, order, category, cat_options):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
    expand_data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    # task_plan.reset_index(inplace=True)
    if not os.path.exists(data_path) or not os.path.exists(expand_data_path):
        return None
    else:
        merged_df = pd.read_csv(data_path)
        merged_df.rename(columns={'llavav1.5-7b': 'llava15-7b', 'llavav1.5-13b': 'llava15-13b'}, inplace=True)
        expand_df = pd.read_csv(expand_data_path)

        df = merged_df

        select_top = rank == "top"
        # Model X is good / bad at
        for model in models:
            if baseline:
                df = df[df[model] >= df[baseline]]
            else:
                if select_top:
                    df = df[df[model] >= threshold]
                else:
                    df = df[df[model] <= threshold]
        if not baseline:
            df['mean score'] = df[models].mean(axis=1)
            df = df.sort_values(by='mean score', ascending=False)
            df = df.iloc[:k, :] if select_top else df.iloc[-k:, :]

        task_ids = list(df.index)
        if baseline:
            models += [baseline]

        chart_df = expand_df[expand_df['model'].isin(models)]
        chart_df = chart_df[chart_df['task id'].isin(task_ids)]

        if cat_options:
            df = chart_df[chart_df[category].isin(cat_options)]
        else:
            df = chart_df
        if baseline:
            model_str = (', '.join(models) if len(models) > 1 else models[0])
            phrase = f'{model_str} perform' if len(models) > 1 else f'{model_str} performs'
            title = f"Are there any tasks where {phrase} better than {baseline} (by {category})?"
        else:
            title = f"What tasks are models {'best' if select_top else 'worst'} at by {category}?"

        if len(models) > 1:
            chart = alt.Chart(df).mark_bar().encode(
                alt.X('model:N',
                      sort=alt.EncodingSortField(field='score', order=order, op="mean"),
                      axis=alt.Axis(labels=False, tickSize=0, title=None)),
                alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
                alt.Color('model:N').legend(),
                alt.Column(f'{category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom'))
            ).properties(
                width=200,
                height=100,
                title=title
            )
        else:
            chart = alt.Chart(df).mark_bar().encode(
                alt.X(f'{category}:N', sort=alt.EncodingSortField(field='score', order=order, op="mean")),
                alt.Y('mean(score):Q', scale=alt.Scale(zero=True)),
                alt.Color(f'{category}:N').legend(None),
            ).properties(
                width=200,
                height=100,
                title=f"What tasks is model {models[0]} {'best' if select_top else 'worst'} at by {category}?"
            )
        chart = chart.configure_title(fontSize=15, offset=5, orient='top', anchor='middle')
        return chart


def get_frequent_patterns(task_plan, scores):
    return find_frequent_patterns(k=10, df=task_plan, scores=scores)  # the original dropped the return value

def list_directories(path):
    """List all directories within a given path."""
    return [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]


def update_category(domain, partition):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl"
    if os.path.exists(data_path):
        data = pickle.load(open(data_path, 'rb'))
        categories = list(data.columns)
        category = gr.Dropdown(categories + ["task id"], value=None, label="task metadata", interactive=True)
        return category
    else:
        return gr.Dropdown([], value=None, label="task metadata")

def update_category2(domain, partition, existing_category):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl"
    if os.path.exists(data_path):
        data = pickle.load(open(data_path, 'rb'))
        categories = list(data.columns)
        if existing_category and existing_category in categories:
            categories.remove(existing_category)
        category = gr.Dropdown(categories, value=None, label="Optional: second task metadata", interactive=True)
        return category
    else:
        return gr.Dropdown([], value=None, label="task metadata")

def update_partition(domain):
    domain = domain2folder[domain]
    path = f"{BASE_DIR}/{domain}"
    if os.path.exists(path):
        partitions = list_directories(path)
        return gr.Dropdown(partitions, value=partitions[0], label="task space of the following task generator")
    else:
        return gr.Dropdown([], value=None, label="task space of the following task generator")

def update_k(domain, partition, category=None):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
    if os.path.exists(data_path):
        data = pd.read_csv(data_path)
        max_k = len(data[category].unique()) if category and category != "task id" else len(data)
        mid = max_k // 2
        return gr.Slider(1, max_k, mid, step=1.0, label="k")
    else:
        return gr.Slider(1, 1, 1, step=1.0, label="k")

# def update_category_values(domain, partition, category):
#     data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
#     if os.path.exists(data_path) and category is not None:
#         data = pd.read_csv(data_path)
#         uni_cats = list(data[category].unique())
#         return gr.Dropdown(uni_cats, multiselect=True, value=None, interactive=True, label="category values")
#     else:
#         return gr.Dropdown([], multiselect=True, value=None, interactive=False, label="category values")

# def update_category_values(domain, partition, models, rank, k, threshold, baseline, category):
#     data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
#     if not os.path.exists(data_path):
#         return gr.Dropdown([], multiselect=True, value=None, interactive=False, label="category values")
#     else:
#         merged_df = pd.read_csv(data_path)
#         merged_df.rename(columns={'llavav1.5-7b': 'llava15-7b', 'llavav1.5-13b': 'llava15-13b'}, inplace=True)
#         df = merged_df
#         select_top = rank == "top"
#         # Model X is good / bad at
#         for model in models:
#             if baseline:
#                 df = df[df[model] >= df[baseline]]
#             else:
#                 if select_top:
#                     df = df[df[model] >= threshold]
#                 else:
#                     df = df[df[model] <= threshold]
#         if not baseline:
#             df['mean score'] = df[models].mean(axis=1)
#             df = df.sort_values(by='mean score', ascending=False)
#             df = df.iloc[:k, :] if select_top else df.iloc[-k:, :]
#         uni_cats = list(df[category].unique())
#         return gr.Dropdown(uni_cats, multiselect=True, value=None, interactive=True, label="category values")


def update_tasks(domain, partition, find_pattern):
    domain = domain2folder[domain]
    if find_pattern == "yes":
        k1 = gr.Slider(1, 10000, 10, step=1.0, label="k", interactive=True)
        pattern = gr.Dropdown([], value=None, interactive=True, label="pattern")
        category1 = gr.Dropdown([], value=None, interactive=False, label="task metadata")
        return [k1, pattern, category1]
    else:
        k1 = gr.Slider(1, 10000, 10, step=1.0, label="k", interactive=False)
        pattern = gr.Dropdown([], value=None, interactive=False, label="pattern")

        data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
        if os.path.exists(data_path):
            data = pd.read_csv(data_path)
            non_columns = MODELS + ['question', 'answer']
            categories = [cat for cat in list(data.columns) if cat not in non_columns]
            category1 = gr.Dropdown(categories, value=categories[0], interactive=True, label="task metadata")
        else:
            category1 = gr.Dropdown([], value=None, label="task metadata")
        return [k1, pattern, category1]


def update_pattern(domain, partition, k):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/patterns.pkl"
    if not os.path.exists(data_path):
        return gr.Dropdown([], value=None, interactive=False, label="pattern")
    else:
        results = pickle.load(open(data_path, 'rb'))
        patterns = results[0]
        patterns = [str(p) for p in patterns]
        print(patterns)
        return gr.Dropdown(patterns[:k], value=None, interactive=True, label="pattern")

def update_threshold(domain, partition, baseline):
    domain = domain2folder[domain]
    print(baseline)
    if baseline:
        # comparing against a baseline: rank / k / threshold are irrelevant, so freeze them
        rank = gr.Radio(['top', 'bottom'], value='top', label="rank", interactive=False)
        k = gr.Slider(1, 10000, 10, step=1.0, label="k", interactive=False)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=False)
        return [rank, k, threshold]
    else:
        data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
        if os.path.exists(data_path):
            data = pd.read_csv(data_path)
            max_k = len(data)
            print(max_k)
            k = gr.Slider(1, max_k, 10, step=1.0, label="k", interactive=True)
        else:
            k = gr.Slider(1, 1, 1, step=1.0, label="k")
        rank = gr.Radio(['top', 'bottom'], value='top', label="rank", interactive=True)

        threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=True)
        return [rank, k, threshold]

def calc_surprisingness(model, scores, embeddings, k):
    # For every task, compare the model's score against its scores on the k most
    # similar tasks: surprisingness = (score - neighbor score) * similarity.
    scores = scores[model].to_numpy()
    sim = embeddings @ embeddings.T
    indices = np.argsort(-sim)[:, :k]  # indices of the k nearest neighbors per task
    score_diff = scores[:, None] - scores[indices]
    sim = sim[np.arange(len(scores))[:, None], indices]  # similarities of those neighbors
    all_surprisingness = score_diff * sim
    mean_surprisingness = np.mean(score_diff * sim, axis=1)
    res = {'similarity': sim,
           'task index': indices,
           'score difference': score_diff,
           'all surprisingness': all_surprisingness,
           'mean surprisingness': mean_surprisingness
           }
    return res
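
def _surprisingness_demo():
    # Illustrative sketch (added for this writeup, not part of the original commit):
    # calc_surprisingness on toy, made-up inputs. Three orthogonal task embeddings
    # mean each task's top-2 "neighbors" are itself plus one dissimilar task, so the
    # score differences are weighted down by the near-zero similarities.
    toy_scores = pd.DataFrame({'qwenvl': [1.0, 0.0, 1.0]})
    toy_embeds = np.eye(3)  # unit-norm, mutually orthogonal embeddings
    out = calc_surprisingness('qwenvl', toy_scores, toy_embeds, k=2)
    return out['mean surprisingness']  # one value per task, shape (3,)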


def plot_surprisingness(domain, partition, model, rank, k, num_neighbors):
    domain = domain2folder[domain]
    # model = model[0]
    model_str = model.replace("-", "_")

    # sp_path = f"{BASE_DIR}/{domain}/{partition}/surprise_data.csv"
    sp_pkl = f"{BASE_DIR}/{domain}/{partition}/{model_str}_surprise.pkl"
    merged_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"
    if os.path.exists(sp_pkl) and os.path.exists(merged_path):  # and not os.path.exists(sp_path)
        # if os.path.exists(sp_path):
        #     sp_df = pd.read_csv(sp_path)
        #     # res = calc_surprisingness(model, scores, embeds, num_neighbors)
        #     # k = 10
        #     model = 'qwenvl'
        #     num_neighbors = 10
        # if os.path.exists(sp_pkl):
        res = pickle.load(open(sp_pkl, 'rb'))

        total_num_task = res['task index'].shape[0]
        all_records = []
        for i in range(total_num_task):
            mean_surprisingness = np.mean(res['all surprisingness'][i, :num_neighbors])
            for j in range(num_neighbors):
                neighbor_id = res['task index'][i, j]
                score_diff = res['score difference'][i, j]
                surprisingness = res['all surprisingness'][i, j]
                similarity = res['similarity'][i, j]

                record = {"task id": i,
                          "neighbor rank": j,
                          "neighbor id": neighbor_id,
                          "score difference": score_diff,
                          "surprisingness": surprisingness,
                          "mean surprisingness": mean_surprisingness,
                          "similarity": similarity
                          }
                all_records.append(record)
        sp_df = pd.DataFrame.from_records(all_records)
        sp_df = sp_df.sort_values(by="mean surprisingness", ascending=False)

        num_rows = k * num_neighbors
        # copy to avoid pandas' SettingWithCopyWarning on the column assignments below
        df = (sp_df.iloc[:num_rows, :] if rank == "top" else sp_df.iloc[-num_rows:, :]).copy()
        print(len(df))

        df['is target'] = df.apply(lambda row: int(row['task id'] == row['neighbor id']), axis=1)

        merged_df = pd.read_csv(merged_path)
        for col in merged_df.columns:
            df[col] = df.apply(lambda row: merged_df.iloc[int(row['neighbor id']), :][col], axis=1)

        tooltips = ['neighbor id'] + ['image', 'question', 'answer', model]

        print(df.head())
        pts = alt.selection_point(encodings=['x'])
        embeds = alt.Chart(df).mark_point(size=30, filled=True).encode(
            alt.OpacityValue(0.5),
            alt.X('x:Q', scale=alt.Scale(zero=False)),
            alt.Y('y:Q', scale=alt.Scale(zero=False)),
            alt.Color(f'{model}:Q'),  # scale=alt.Scale(domain=[1, 0.5, 0], range=['blue', 'white', 'red'], interpolate='rgb')
            alt.Size("is target:N", legend=None, scale=alt.Scale(domain=[0, 1], range=[300, 500])),
            alt.Shape("is target:N", legend=None, scale=alt.Scale(domain=[0, 1], range=['circle', 'triangle'])),
            alt.Order("is target:N"),
            tooltip=tooltips,
        ).properties(
            width=400,
            height=400,
            title=f"What are the tasks {model} is surprisingly {'good' if rank == 'top' else 'bad'} at compared to {num_neighbors} similar tasks?"
        ).transform_filter(
            pts
        )

        bar = alt.Chart(df).mark_bar().encode(
            alt.Y('mean(mean surprisingness):Q'),
            alt.X('task id:N', sort=alt.EncodingSortField(field='mean surprisingness', order='descending')),
            color=alt.condition(pts, alt.ColorValue("steelblue"), alt.ColorValue("grey")),
        ).add_params(pts).properties(
            width=400,
            height=200,
        )

        chart = alt.hconcat(
            bar,
            embeds
        ).resolve_legend(
            color="independent",
            size="independent"
        ).configure_title(
            fontSize=20
        ).configure_legend(
            labelFontSize=10,
            titleFontSize=10,
        )
        return chart
    else:
        print(sp_pkl, merged_path)
        return None


def plot_task_distribution(domain, partition, category):
    domain = domain2folder[domain]
    task_plan = pickle.load(open(f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl", "rb"))
    task_plan.reset_index(inplace=True)
    col_name = category
    task_plan_cnt = task_plan.groupby(col_name)['index'].count().reset_index()
    task_plan_cnt.rename(columns={'index': 'count'}, inplace=True)
    task_plan_cnt['frequency (%)'] = round(task_plan_cnt['count'] / len(task_plan) * 100, 2)

    base = alt.Chart(task_plan_cnt).encode(
        alt.Theta("count:Q").stack(True),
        alt.Color(f"{col_name}:N").legend(),
        tooltip=[col_name, 'count', 'frequency (%)']
    )
    pie = base.mark_arc(outerRadius=120)
    return pie

def plot_all(domain, partition, models, category1, category2, agg):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        return None
    expand_df = pd.read_csv(data_path)
    chart_df = expand_df[expand_df['model'].isin(models)]
    if category2:
        color_val = f'{agg}(score):Q'

        chart = alt.Chart(chart_df).mark_rect().encode(
            alt.X(f'{category1}:N', sort=alt.EncodingSortField(field='score', order='ascending', op=agg)),
            alt.Y(f'{category2}:N', sort=alt.EncodingSortField(field='score', order='descending', op=agg)),
            alt.Color(color_val),
            alt.Tooltip('score', aggregate=agg, title=f"{agg} score"),
        ).properties(
            width=800,
            height=200,
        )
    else:
        category = "index" if category1 == "task id" else category1
        # cat_options = list(chart_df[category].unique())
        # cat_options = cat_options[:5]
        y_val = f'{agg}(score):Q'
        df = chart_df
        # df = chart_df[chart_df[category].isin(cat_options)]
        if len(models) > 1:
            chart = alt.Chart(df).mark_bar().encode(
                alt.X('model:N',
                      sort=alt.EncodingSortField(field='score', order='ascending', op=agg),
                      axis=alt.Axis(labels=False, tickSize=0, title=None)),
                alt.Y(y_val, scale=alt.Scale(zero=True)),
                alt.Color('model:N').legend(),
                alt.Column(f'{category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom'))
            ).properties(
                width=200,
                height=100,
                title=f"How do models perform across {category}?"
            )
        else:
            chart = alt.Chart(df).mark_bar().encode(
                alt.X(f'{category}:N', sort=alt.EncodingSortField(field='score', order='ascending', op=agg)),
                alt.Y(y_val, scale=alt.Scale(zero=True)),
                alt.Color(f'{category}:N').legend(None),
            ).properties(
                width=200,
                height=100,
                title=f"How does {models[0]} perform across {category}?"
            )
    chart = chart.configure_title(fontSize=20, offset=5, orient='top', anchor='middle').configure_axis(
        labelFontSize=20,
        titleFontSize=20,
    ).configure_legend(
        labelFontSize=15,
        titleFontSize=15,
    )
    return chart

def update_widgets(domain, partition, category, query_type):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    if not os.path.exists(data_path):
        print("here?")
        return [None] * 11
    df = pd.read_csv(data_path)
    max_k = len(df[category].unique()) if category and category != "task id" else len(df)

    if query_type == "top k":
        # aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)
        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", interactive=True, visible=True)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", interactive=True, visible=True)
        model = gr.Dropdown(MODELS, value=MODELS, label="of model(s)'", multiselect=True, interactive=True, visible=True)
        # model_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="task category aggregate", interactive=True, visible=True)
        model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)

        baseline = gr.Dropdown(MODELS, value=None, label="baseline", visible=False)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", visible=False)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", visible=False)
        baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="baseline aggregate", visible=False)
        md1 = gr.Markdown(r"<h2>ranked by the </h2>")
        md2 = gr.Markdown(r"<h2>accuracy</h2>")
        md3 = gr.Markdown(r"")

    elif query_type == "threshold":
        # aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)
        model = gr.Dropdown(MODELS, value=MODELS[0], label="of model(s)'", multiselect=True, interactive=True, visible=True)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", interactive=True, visible=True)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=True, visible=True)
        model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)

        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", visible=False)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", visible=False)
        baseline = gr.Dropdown(MODELS, value=None, label="baseline", visible=False)
        baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="baseline aggregate", visible=False)
        md1 = gr.Markdown(r"<h2>where the</h2>")
        md2 = gr.Markdown(r"<h2>accuracy is</h2>")
        md3 = gr.Markdown(r"")

    elif query_type == "model comparison":
        model = gr.Dropdown(MODELS, value=MODELS[0], label="of model(s)' accuracy", multiselect=True, interactive=True, visible=True)
        baseline = gr.Dropdown(MODELS, value=None, label="of baseline(s)' accuracy", multiselect=True, interactive=True, visible=True)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", interactive=True, visible=True)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", interactive=True, visible=True)
        model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)
        # baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="task category aggregate (over baselines)", interactive=True, visible=True)
        baseline_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True)

        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", visible=False)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", visible=False)
        md1 = gr.Markdown(r"<h2>where the difference between the </h2>")
        md2 = gr.Markdown(r"<h2>is </h2>")
        md3 = gr.Markdown(r"<h2>and the</h2>")

    elif query_type == "model debugging":
        model = gr.Dropdown(MODELS, value=MODELS[0], label="model's", multiselect=False, interactive=True, visible=True)

        baseline = gr.Dropdown(MODELS, value=None, label="baseline", visible=False)
        direction = gr.Radio(['above', 'below'], value='above', label=" ", visible=False)
        threshold = gr.Slider(0, 1, 0.0, label="threshold", visible=False)
        rank = gr.Radio(['top', 'bottom'], value='top', label=" ", visible=False)
        k = gr.Slider(1, max_k, max_k // 2, step=1.0, label="k", visible=False)
        model_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="task category aggregate (over models)", visible=False)
        baseline_aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="baseline aggregate", visible=False)
        md1 = gr.Markdown(r"<h2>where </h2>")
        md2 = gr.Markdown(r"<h2>mean accuracy is below its overall mean accuracy by one standard deviation</h2>")
        md3 = gr.Markdown(r"")
    else:
        # Unknown query type: return early. (The original assigned [None] * 11 here and
        # then fell through to the list below, which would raise a NameError.)
        return [None] * 11
    widgets = [rank, k, direction, threshold, model, model_aggregate, baseline, baseline_aggregate, md1, md2, md3]

    return widgets

def select_tasks(domain, partition, category, query_type, task_agg, models, model_agg, rank, k, direction, threshold, baselines, baseline_agg):
    domain = domain2folder[domain]
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    merged_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"

    if not os.path.exists(data_path) or not os.path.exists(merged_path):
        return gr.DataFrame(None)
    df = pd.read_csv(data_path)
    merged_df = pd.read_csv(merged_path)
    task_plan = pickle.load(open(f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl", 'rb'))
    task_plan.reset_index(inplace=True)
    if not category or category == "task id":
        category = 'index'

    if query_type == "top k":
        df = df[df['model'].isin(models)]
        df = df.groupby([category, 'model'])['score'].agg(task_agg).reset_index()
        df = df.groupby([category])['score'].agg(model_agg).reset_index()
        df = df.sort_values(by='score', ascending=False)
        if rank == "bottom":
            df = df.iloc[-k:, :]
        else:
            df = df.iloc[:k, :]
    elif query_type == "threshold":
        df = df[df['model'].isin(models)]
        df = df.groupby([category, 'model'])['score'].agg(task_agg).reset_index()
        df = df.groupby([category])['score'].agg(model_agg).reset_index()
        if direction == "below":
            df = df[df['score'] <= threshold]
        else:
            df = df[df['score'] >= threshold]
    elif query_type == "model comparison":
        # df = merged_df
        # df.reset_index(inplace=True)
        # df = df.groupby([category])[[model, baseline]].agg(task_agg).reset_index()
        # df = df[(df[model] - df[baseline] > threshold)]
        df_baseline = deepcopy(df)

        df = df[df['model'].isin(models)]
        df = df.groupby([category, 'model'])['score'].agg(task_agg).reset_index()
        df = df.groupby([category])['score'].agg(model_agg).reset_index()
        model_str = ', '.join(models)
        exp_score_id = f'{model_agg}({model_str})' if len(models) > 1 else model_str
        df = df.sort_values(by=category)

        df_baseline = df_baseline[df_baseline['model'].isin(baselines)]
        df_baseline = df_baseline.groupby([category, 'model'])['score'].agg(task_agg).reset_index()
        df_baseline = df_baseline.groupby([category])['score'].agg(baseline_agg).reset_index()
        model_str = ', '.join(baselines)
        baseline_score_id = f'{baseline_agg}({model_str})' if len(baselines) > 1 else model_str
        df_baseline = df_baseline.sort_values(by=category)

        df.rename(columns={'score': exp_score_id}, inplace=True)
        df_baseline.rename(columns={'score': baseline_score_id}, inplace=True)
        df = pd.merge(df, df_baseline, on=category)
        df = df[(df[exp_score_id] - df[baseline_score_id] > threshold)]

    elif query_type == "model debugging":
        model = models  # a single model name here (the dropdown is single-select for this query type)
        print(models)
        avg_acc = merged_df[model].mean()
        std = merged_df[model].std()
        t = avg_acc - std
        df = df[df['model'] == model]
        df = df.groupby(['model', category])['score'].agg(task_agg).reset_index()
        df = df[df['score'] < t]
        df['mean'] = round(avg_acc, 4)
        df['std'] = round(std, 4)

    print(df.head())
    if category == 'index':
        task_attrs = list(df[category])
        selected_tasks = task_plan[task_plan[category].isin(task_attrs)]

        if len(selected_tasks) == 0:
            return gr.DataFrame(None, label="There is no such task.")

        if query_type == "model comparison" and (models and baselines):
            # selected_tasks[model] = selected_tasks.apply(lambda row: df[df['index'] == row['index']][model].values[0], axis=1)
            # selected_tasks[baseline] = selected_tasks.apply(lambda row: df[df['index'] == row['index']][baseline].values[0], axis=1)
            selected_tasks[exp_score_id] = selected_tasks.apply(lambda row: df[df['index'] == row['index']][exp_score_id].values[0], axis=1)
            selected_tasks[baseline_score_id] = selected_tasks.apply(lambda row: df[df['index'] == row['index']][baseline_score_id].values[0], axis=1)
        else:
            selected_tasks['score'] = selected_tasks.apply(lambda row: df[df['index'] == row['index']]['score'].values[0], axis=1)

        print(selected_tasks.head())
        return gr.DataFrame(selected_tasks, label=f"There are {len(selected_tasks)} (out of {len(task_plan)}) tasks in total.")
    else:
        if len(df) == 0:
            return gr.DataFrame(None, label=f"There is no such {category}.")
        else:
            return gr.DataFrame(df, label=f"The total number of such {category} is {len(df)}.")
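
def _select_tasks_demo():
    # Illustrative sketch (added for this writeup, not part of the original commit) of the
    # two-stage aggregation select_tasks relies on: first aggregate each (category, model)
    # pair's scores with task_agg, then collapse across models with model_agg. Toy data.
    df = pd.DataFrame({'category': ['a', 'a', 'b', 'b'],
                       'model':    ['m1', 'm2', 'm1', 'm2'],
                       'score':    [1.0, 0.0, 0.5, 0.5]})
    per_model = df.groupby(['category', 'model'])['score'].agg('mean').reset_index()
    per_cat = per_model.groupby(['category'])['score'].agg('mean').reset_index()
    return per_cat  # categories 'a' and 'b' both end up with a mean score of 0.5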


def find_patterns(selected_tasks, num_patterns, models, baselines, model_agg, baseline_agg):
    if len(selected_tasks) == 0:
        return gr.DataFrame(None)
    print(selected_tasks.head())
    if 'score' in selected_tasks:
        scores = selected_tasks['score']
    # elif model in selected_tasks:
    #     scores = selected_tasks[model]
    else:
        scores = None
    print(scores)

    model_str = ', '.join(models)
    exp_score_id = f'{model_agg}({model_str})' if len(models) > 1 else model_str
    if baselines:
        baseline_str = ', '.join(baselines)
        baseline_score_id = f'{baseline_agg}({baseline_str})' if len(baselines) > 1 else baseline_str

    # drop the score columns so pattern mining runs over task metadata only
    tasks_only = selected_tasks
    all_score_cols = ['score', exp_score_id]
    if baselines:
        all_score_cols += [baseline_score_id]
    for name in all_score_cols:
        if name in selected_tasks:
            tasks_only = tasks_only.drop(name, axis=1)
    results = find_frequent_patterns(k=num_patterns, df=tasks_only, scores=scores)
    records = []
    if scores is not None:
        patterns, scores = results[0], results[1]
        for pattern, score in zip(patterns, scores):
            pattern_str = ""
            for t in pattern[1]:  # each pattern is (count, [(column, value), ...])
                col_name, col_val = t
                pattern_str += f"{col_name} = {col_val}, "

            record = {'pattern': pattern_str[:-2], 'count': pattern[0], 'score': score}
            records.append(record)
    else:
        patterns = results
        for pattern in patterns:
            pattern_str = ""
            for t in pattern[1]:
                col_name, col_val = t
                pattern_str += f"{col_name} = {col_val}, "

            record = {'pattern': pattern_str[:-2], 'count': pattern[0]}
            records.append(record)

    df = pd.DataFrame.from_records(records)
    return gr.DataFrame(df)
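
def _pattern_string_demo():
    # Illustrative sketch (added for this writeup, not part of the original commit):
    # how a mined pattern tuple, assumed to look like (count, [(column, value), ...]),
    # is rendered into the display string used throughout this file.
    pattern = (42, [('shape', 'star'), ('size', 'small')])  # hypothetical pattern
    pattern_str = ""
    for col_name, col_val in pattern[1]:
        pattern_str += f"{col_name} = {col_val}, "
    return pattern_str[:-2]  # -> "shape = star, size = small"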

def visualize_task_distribution(selected_tasks, col_name, model1, model2):
    if not col_name:
        return None
    task_plan_cnt = selected_tasks.groupby(col_name)['index'].count().reset_index()
    task_plan_cnt.rename(columns={'index': 'count'}, inplace=True)
    task_plan_cnt['frequency (%)'] = round(task_plan_cnt['count'] / len(selected_tasks) * 100, 2)
    print(task_plan_cnt.head())

    tooltips = [col_name, 'count', 'frequency (%)']
    base = alt.Chart(task_plan_cnt).encode(
        alt.Theta("count:Q").stack(True),
        alt.Color(f"{col_name}:N").legend(),
        tooltip=tooltips
    )
    pie = base.mark_arc(outerRadius=120)

    return pie

def plot_performance_for_selected_tasks(domain, partition, df, query_type, models, baselines, select_category, vis_category, task_agg, model_agg, baseline_agg, rank, direction, threshold):
    domain = domain2folder[domain]
    task_agg = "mean"  # per-task scores are always averaged for plotting
    data_path = f"{BASE_DIR}/{domain}/{partition}/expanded_data.csv"
    merged_data_path = f"{BASE_DIR}/{domain}/{partition}/merged_data.csv"  # fixed the "mereged" typo

    if not os.path.exists(data_path) or not os.path.exists(merged_data_path) or len(df) == 0:
        return None

    select_tasks = select_category == "task id" and vis_category
    if select_tasks:  # individual tasks were selected
        y_val = f'{task_agg}(score):Q'
    else:  # task categories were selected
        y_val = 'score:Q'

    if select_category == "task id":
        select_category = "index"
    print(df.head())
    if query_type == "model comparison":
        # re-format the data for plotting
        model_str = ', '.join(models)
        exp_score_id = f'{model_agg}({model_str})' if len(models) > 1 else model_str
        baseline_str = ', '.join(baselines)
        baseline_score_id = f'{baseline_agg}({baseline_str})' if len(baselines) > 1 else baseline_str
        # other_cols = list(df.columns)
        # other_cols.remove(select_category)
        print(exp_score_id, baseline_score_id)
        df = df.melt(id_vars=[select_category], value_vars=[exp_score_id, baseline_score_id])
        df.rename(columns={'variable': 'model', 'value': 'score'}, inplace=True)
        print(df.head())

        if select_tasks:
            merged_df = pd.read_csv(merged_data_path)
            df[vis_category] = df.apply(lambda row: merged_df[merged_df.index == row['index']][vis_category].values[0], axis=1)

        num_columns = len(df['model'].unique()) * len(df[vis_category].unique())
        chart = alt.Chart(df).mark_bar().encode(
            alt.X('model:N',
                  sort=alt.EncodingSortField(field='score', order='descending', op=task_agg),
                  axis=alt.Axis(labels=False, tickSize=0, title=None)),
            alt.Y(y_val, scale=alt.Scale(zero=True), title="accuracy"),
            alt.Color('model:N').legend(),
            alt.Column(f'{vis_category}:N', header=alt.Header(titleOrient='bottom', labelOrient='bottom', labelFontSize=20, titleFontSize=20))
        ).properties(
            width=num_columns * 30,
            height=200,
            title=f"How do models perform by {vis_category}?"
        )
        print(num_columns * 50)
    else:
        if query_type == "model debugging":
            y_title = "accuracy"
            plot_title = f"{models} performs worse than its (mean - std) on these {vis_category}s"
            models = [models]  # a single model name for this query type
        else:
            model_str = ', '.join(models)
            y_title = f"{model_agg} accuracy" if len(models) > 0 else "accuracy"
            suffix = f"on these tasks (by {vis_category})" if select_category == "index" else f"on these {vis_category}s"
            if query_type == "top k":
                plot_title = f"The {model_agg} accuracy of {model_str} is the {'highest' if rank == 'top' else 'lowest'} " + suffix
            elif query_type == "threshold":
                plot_title = f"The {model_agg} accuracy of {model_str} is {direction} {threshold} " + suffix

        if select_tasks:
            expand_df = pd.read_csv(data_path)
            task_ids = list(df['index'].unique())

            # all_models = (models + baselines) if baselines else models
            df = expand_df[(expand_df['model'].isin(models)) & (expand_df['task id'].isin(task_ids))]

        num_columns = len(df[vis_category].unique())
        chart = alt.Chart(df).mark_bar().encode(
            alt.X(f'{vis_category}:N', sort=alt.EncodingSortField(field='score', order='ascending', op=task_agg), axis=alt.Axis(labelAngle=-20)),
            alt.Y(y_val, scale=alt.Scale(zero=True), title=y_title),
            alt.Color(f'{vis_category}:N').legend(None),
        ).properties(
            width=num_columns * 30,
            height=200,
            title=plot_title
        )

    chart = chart.configure_title(fontSize=20, offset=5, orient='top', anchor='middle').configure_axis(
        labelFontSize=20,
        titleFontSize=20,
    ).configure_legend(
        labelFontSize=20,
        titleFontSize=20,
        labelLimit=200,
    )
    return chart

def sync_vis_category(domain, partition, category):
    domain = domain2folder[domain]
    if category and category != "task id":
        return [gr.Dropdown([category], value=category, label="by task metadata", interactive=False),
                gr.Dropdown([category], value=category, label="by task metadata", interactive=False)]
    else:
        data_path = f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl"
        if os.path.exists(data_path):
            data = pickle.load(open(data_path, 'rb'))
            categories = list(data.columns)
            return [gr.Dropdown(categories, value=categories[0], label="by task metadata", interactive=True),
                    gr.Dropdown(categories, value=categories[0], label="by task metadata", interactive=True)]
        else:
            return [None, None]

def hide_fpm_and_dist_components(domain, partition, category):
    domain = domain2folder[domain]
    print(category)
    if category and category != "task id":
        num_patterns = gr.Slider(1, 100, 50, step=1.0, label="number of patterns", visible=False)
        btn_pattern = gr.Button(value="Find patterns among tasks", visible=False)

        table = gr.DataFrame({}, height=250, visible=False)
        dist_chart = gr.Plot(visible=False)

        col_name = gr.Dropdown([], value=None, label="by task metadata", visible=False)
        btn_dist = gr.Button(value="Visualize task distribution", visible=False)
    else:
        data_path = f"{BASE_DIR}/{domain}/{partition}/task_plan.pkl"
        if os.path.exists(data_path):
            data = pickle.load(open(data_path, 'rb'))
            categories = list(data.columns)
            col_name = gr.Dropdown(categories, value=categories[0], label="by task metadata", interactive=True, visible=True)
        else:
            col_name = gr.Dropdown([], value=None, label="by task metadata", interactive=True, visible=True)

        num_patterns = gr.Slider(1, 100, 50, step=1.0, label="number of patterns", interactive=True, visible=True)
        btn_pattern = gr.Button(value="Find patterns among tasks", interactive=True, visible=True)

        table = gr.DataFrame({}, height=250, interactive=True, visible=True)
        dist_chart = gr.Plot(visible=True)

        btn_dist = gr.Button(value="Visualize task distribution", interactive=True, visible=True)
    return [num_patterns, btn_pattern, table, col_name, btn_dist, dist_chart]
1026
+
1027
+
1028
+
1029
+ # domains = list_directories(BASE_DIR)
1030
+ theme = gr.Theme.from_hub('sudeepshouche/minimalist')
1031
+ theme.font = [gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"] # gr.themes.GoogleFont("Source Sans Pro") # [gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"]
1032
+ theme.text_size = gr.themes.sizes.text_lg
1033
+ # theme = theme.set(font=)
1034
+
1035
+ demo = gr.Blocks(theme=theme, title="TaskVerse-UI") #
1036
+ with demo:
1037
+ with gr.Row():
1038
+ with gr.Column(scale=1):
1039
+ gr.Markdown(
1040
+ r""
1041
+ )
1042
+ with gr.Column(scale=1):
1043
+ gr.Markdown(
1044
+ r"<h1>Welcome to TaskVerse-UI! </h1>"
1045
+ )
1046
+ with gr.Column(scale=1):
1047
+ gr.Markdown(
1048
+ r""
1049
+ )
1050
+
1051
+ with gr.Tab("📊 Overview"):
1052
+ gr.Markdown(
1053
+ r"<h2>📊 Visualize the overall task distribution and model performance </h2>"
1054
+ )
1055
+
1056
+ with gr.Row():
1057
+ domain = gr.Radio(domains, label="scenario", scale=2)
1058
+ partition = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
1059
+ # domain.change(fn=update_partition, inputs=domain, outputs=partition)
1060
+
1061
+
1062
+ gr.Markdown(
1063
+ r"<h2>Overall task metadata distribution</h2>"
1064
+ )
1065
+
1066
+ with gr.Row():
1067
+ category = gr.Dropdown([], value=None, label="task metadata")
1068
+ partition.change(fn=update_category, inputs=[domain, partition], outputs=category)
1069
+ with gr.Row():
1070
+ output = gr.Plot()
1071
+ with gr.Row():
1072
+ btn = gr.Button(value="Plot")
1073
+ btn.click(plot_task_distribution, [domain, partition, category], output)
1074
+
1075
+ gr.Markdown(
1076
+ r"<h2>Models' overall performance by task metadata</h2>"
1077
+ )
1078
+ with gr.Row():
1079
+ with gr.Column(scale=2):
1080
+ models = gr.CheckboxGroup(MODELS, label="model(s)", value=MODELS)
1081
+ with gr.Column(scale=1):
1082
+ aggregate = gr.Radio(['mean', 'median', 'min', 'max'], value="mean", label="aggregate models' accuracy by")
1083
+ with gr.Row():
1084
+ # with gr.Column(scale=1):
1085
+ category1 = gr.Dropdown([], value=None, label="task metadata", interactive=True)
1086
+ category2 = gr.Dropdown([], value=None, label="Optional: second task metadata", interactive=True)
1087
+ partition.change(fn=update_category, inputs=[domain, partition], outputs=category1)
1088
+ category1.change(fn=update_category2, inputs=[domain, partition, category1], outputs=category2)
1089
+ domain.change(fn=update_partition_and_models, inputs=domain, outputs=[partition, models])
1090
+ with gr.Row():
1091
+ output = gr.Plot()
1092
+ with gr.Row():
1093
+ btn = gr.Button(value="Plot")
1094
+ btn.click(plot_all, [domain, partition, models, category1, category2, aggregate], output)
1095
+ # gr.Examples(["hello", "bonjour", "merhaba"], input_textbox)
1096
+
1097
+
1098
+ with gr.Tab("✨ Embedding"):
1099
+ gr.Markdown(
1100
+ r"<h2>✨ Visualize the tasks' embeddings in the 2D space </h2>"
1101
+ )
1102
+ with gr.Row():
1103
+ domain2 = gr.Radio(domains, label="scenario", scale=2)
1104
+ # domain = gr.Dropdown(domains, value=domains[0], label="scenario")
1105
+ partition2 = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
1106
+ category2 = gr.Dropdown([], value=None, label="colored by task metadata", scale=1)
1107
+ domain2.change(fn=update_partition, inputs=domain2, outputs=partition2)
1108
+ partition2.change(fn=update_category, inputs=[domain2, partition2], outputs=category2)
1109
+
1110
+ with gr.Row():
1111
+ output2 = gr.Plot()
1112
+ with gr.Row():
1113
+ btn = gr.Button(value="Run")
1114
+ btn.click(plot_embedding, [domain2, partition2, category2], output2)
1115
+
1116
+
1117
+ with gr.Tab("❓ Query"):
1118
+ gr.Markdown(
1119
+ r"<h2>❓ Find out the answers to your queries by finding and visualizing the relevant tasks and models' performance </h2>"
1120
+ )
1121
+ with gr.Row(equal_height=True):
1122
+ domain = gr.Radio(domains, label="scenario", scale=2)
1123
+ partition = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
1124
+ with gr.Row():
1125
+ query1 = "top k"
1126
+ query2 = "threshold"
1127
+ query3 = "model debugging"
1128
+ query4 = "model comparison"
1129
+ query_type = gr.Radio([query1, query2, query3, query4], value="top k", label=r"query type")
1130
+ with gr.Row():
1131
+ with gr.Accordion("See more details about the query type"):
1132
+ gr.Markdown(
1133
+ r"<ul><li>Top k: Find the k tasks or task metadata that the model(s) perform the best or worst on</li><li>Threshold: Find the tasks or task metadata where the model(s)' performance is greater or lower than a given threshold t</li><li>Model debugging: Find the tasks or task metadata where a model performs significantly worse than its average performance (by one standard deviation)</li><li>Model comparison: Find the tasks or task metadata where some model(s) perform better or worse than the baseline(s) by a given threshold t</li></ul>"
1134
+ )
1135
+
+     with gr.Row():
+         gr.Markdown(r"<h2>Help me find the</h2>")
+     with gr.Row(equal_height=True):
+         rank = gr.Radio(['top', 'bottom'], value='top', label=" ", interactive=True, visible=True)
+         k = gr.Slider(1, 10, 2, step=1.0, label="k", interactive=True, visible=True)
+         category = gr.Dropdown([], value=None, label="tasks / task metadata", interactive=True)
+
+     with gr.Row():
+         md1 = gr.Markdown(r"<h2>ranked by the</h2>")
+
+     with gr.Row(equal_height=True):
+         model_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=True, scale=1)
+         model = gr.Dropdown(MODELS, value=MODELS, label="of model(s)", multiselect=True, interactive=True, visible=True, scale=2)
+     with gr.Row():
+         md3 = gr.Markdown(r"")
+     with gr.Row(equal_height=True):
+         baseline_aggregate = gr.Dropdown(['mean', 'median', 'min', 'max'], value="mean", label=" ", interactive=True, visible=False, scale=1)
+         baseline = gr.Dropdown(MODELS, value=None, label="of baseline(s)'", visible=False, scale=2)
+     with gr.Row():
+         md2 = gr.Markdown(r"<h2>accuracy</h2>")
+
+     with gr.Row():
+         direction = gr.Radio(['above', 'below'], value='above', label=" ", visible=False)
+         threshold = gr.Slider(0, 1, 0.0, label="threshold", visible=False)
+
+     widgets = [rank, k, direction, threshold, model, model_aggregate, baseline, baseline_aggregate, md1, md2, md3]
+     partition.change(fn=update_category, inputs=[domain, partition], outputs=category)
+     query_type.change(update_widgets, [domain, partition, category, query_type], widgets)
+     domain.change(fn=update_partition_and_models_and_baselines, inputs=domain, outputs=[partition, model, baseline])
+     with gr.Row():
+         df = gr.DataFrame({}, height=200)
+     btn = gr.Button(value="Find tasks / task metadata")
+     # NB: `aggregate` here is the dropdown defined in the Evaluation tab above.
+     btn.click(select_tasks, [domain, partition, category, query_type, aggregate, model, model_aggregate, rank, k, direction, threshold, baseline, baseline_aggregate], df)
+
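`update_widgets`, wired to `query_type.change` above, is also defined outside this excerpt. A hypothetical sketch of such a handler follows: it returns one `gr.update(...)` per widget in the `widgets` output list, in the same order. Which controls each query type shows or hides is an assumption here, not taken from the source.

```python
import gradio as gr

def update_widgets_sketch(domain, partition, category, query_type):
    """Toggle the eleven widgets wired above: rank/k for "top k" queries,
    direction/threshold otherwise, baseline controls only for comparison."""
    is_topk = query_type == "top k"
    is_cmp = query_type == "model comparison"
    return (
        gr.update(visible=is_topk),      # rank
        gr.update(visible=is_topk),      # k
        gr.update(visible=not is_topk),  # direction
        gr.update(visible=not is_topk),  # threshold
        gr.update(visible=True),         # model
        gr.update(visible=True),         # model_aggregate
        gr.update(visible=is_cmp),       # baseline
        gr.update(visible=is_cmp),       # baseline_aggregate
        gr.update(visible=True),         # md1
        gr.update(visible=True),         # md2
        gr.update(visible=is_cmp),       # md3
    )
```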
+     with gr.Row():
+         plot = gr.Plot()
+     with gr.Row():
+         col_name2 = gr.Dropdown([], value=None, label="by task metadata", interactive=True)
+         btn_plot = gr.Button(value="Plot model performance", interactive=True)
+     partition.change(fn=update_category, inputs=[domain, partition], outputs=col_name2)
+     btn_plot.click(plot_performance_for_selected_tasks, [domain, partition, df, query_type, model, baseline, category, col_name2, aggregate, model_aggregate, baseline_aggregate, rank, direction, threshold], plot)
+
+     with gr.Row():
+         dist_chart = gr.Plot()
+     with gr.Row():
+         col_name = gr.Dropdown([], value=None, label="by task metadata", interactive=True)
+         btn_dist = gr.Button(value="Visualize task distribution", interactive=True)
+     partition.change(fn=update_category, inputs=[domain, partition], outputs=col_name)
+     btn_dist.click(visualize_task_distribution, [df, col_name, model, baseline], dist_chart)
+
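`visualize_task_distribution` is likewise defined earlier in the file. If it reduces to counting how often each metadata value occurs among the selected tasks — an assumption — a minimal sketch looks like this:

```python
import pandas as pd
import plotly.express as px

def visualize_task_distribution_sketch(selected: pd.DataFrame, col_name: str):
    """Bar chart of how the selected tasks distribute over one metadata column."""
    counts = selected[col_name].value_counts().reset_index()
    counts.columns = [col_name, "count"]  # normalize across pandas versions
    return px.bar(counts, x=col_name, y="count",
                  title=f"Distribution of selected tasks by {col_name}")
```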
+     with gr.Row():
+         table = gr.DataFrame({}, height=250)
+     with gr.Row():
+         num_patterns = gr.Slider(1, 100, 50, step=1.0, label="number of patterns")
+         btn_pattern = gr.Button(value="Find patterns among tasks")
+     btn_pattern.click(find_patterns, [df, num_patterns, model, baseline], table)
+
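`find_patterns` is not shown in this diff either. As one simple stand-in for pattern mining over the selected tasks' metadata — purely illustrative, not the app's algorithm — a sketch that counts co-occurring metadata value pairs:

```python
from collections import Counter
from itertools import combinations

import pandas as pd

def find_patterns_sketch(selected: pd.DataFrame, num_patterns: int = 50) -> pd.DataFrame:
    """Return the most frequent pairs of co-occurring metadata values."""
    counter = Counter()
    meta_cols = [c for c in selected.columns if selected[c].dtype == object]
    for _, row in selected[meta_cols].iterrows():
        items = sorted(f"{c}={row[c]}" for c in meta_cols)
        for pair in combinations(items, 2):   # pairs across distinct columns
            counter[pair] += 1
    top = counter.most_common(num_patterns)
    return pd.DataFrame([{"pattern": " & ".join(p), "support": n} for p, n in top])
```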
+     category.change(fn=hide_fpm_and_dist_components, inputs=[domain, partition, category], outputs=[num_patterns, btn_pattern, table, col_name, btn_dist, dist_chart])
+     category.change(fn=sync_vis_category, inputs=[domain, partition, category], outputs=[col_name, col_name2])
+     category.change(fn=update_k, inputs=[domain, partition, category], outputs=k)
+
+
+ with gr.Tab("😮 Surprisingness"):
+     gr.Markdown(r"<h2>😮 Find the tasks a model is surprisingly good or bad at compared to similar tasks</h2>")
+     with gr.Row():
+         domain3 = gr.Radio(domains, label="scenario", scale=2)
+         partition3 = gr.Dropdown([], value=None, label="task space of the following task generator", scale=1)
+     with gr.Row():
+         model3 = gr.Dropdown(MODELS, value=MODELS[0], label="model", interactive=True, visible=True)
+         k3 = gr.Slider(1, 100, 50, step=1.0, label="number of surprising tasks", interactive=True)
+         num_neighbors = gr.Slider(1, 100, 50, step=1.0, label="number of neighbors", interactive=True)
+         rank3 = gr.Radio(['top', 'bottom'], value='top', label=" ", interactive=True, visible=True)
+     domain3.change(fn=update_partition_and_models, inputs=domain3, outputs=[partition3, model3])
+     with gr.Row():
+         output3 = gr.Plot()
+     with gr.Row():
+         btn = gr.Button(value="Plot")
+     btn.click(plot_surprisingness, [domain3, partition3, model3, rank3, k3, num_neighbors], output3)
+
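`plot_surprisingness` reads the per-model `*_surprise.pkl` files added under db/ in this commit. If, as the tab's description suggests, surprisingness compares a task's accuracy against similar tasks, one hypothetical scoring is the gap between a task's accuracy and the mean accuracy of its nearest neighbors in embedding space — a sketch under that assumption:

```python
import numpy as np

def surprisingness_sketch(embeddings: np.ndarray, acc: np.ndarray,
                          num_neighbors: int = 50) -> np.ndarray:
    """Score each task by how far its accuracy deviates from its neighbors'."""
    # cosine similarity between all task embeddings
    norm = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = norm @ norm.T
    np.fill_diagonal(sim, -np.inf)                       # exclude the task itself
    nbrs = np.argsort(-sim, axis=1)[:, :num_neighbors]   # nearest-neighbor indices
    return acc - acc[nbrs].mean(axis=1)                  # positive = surprisingly good
```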
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
+
+
db/2d/2d-how-many/embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:35cfc281579e9d837218ac6edd58db0475c4728e526eb7d0f005d95772229692
+ size 52955299
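Each binary db/ file in this commit is tracked with Git LFS, so the diff shows only the three-line pointer above (spec version, content oid, byte size) rather than the pickle itself. A minimal sketch of parsing such a pointer, assuming the standard LFS pointer format:

```python
def parse_lfs_pointer(text: str) -> dict:
    """Parse a Git LFS pointer file into its version, oid, and size fields."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "version": fields["version"],
        "oid": fields["oid"].removeprefix("sha256:"),
        "size": int(fields["size"]),
    }

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:35cfc281579e9d837218ac6edd58db0475c4728e526eb7d0f005d95772229692
size 52955299"""
assert parse_lfs_pointer(pointer)["size"] == 52955299
```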
db/2d/2d-how-many/expanded_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-how-many/gt.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1afea495ff6a5425049d71163c4607af0942d1d750cc3625ecc310ddec123031
+ size 828220
db/2d/2d-how-many/instructblip_vicuna13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e358eb20d0793cef91ee2e78c9257baa270f9ba5d96964b6605651048c7e0c1e
+ size 48404804
db/2d/2d-how-many/instructblip_vicuna7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4347576213a57fdbc4a62c450072db4488ff83f922920daf8a5f8e48423be26e
+ size 48404804
db/2d/2d-how-many/llava15_13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b2cae2611f360876a8e9386cd4cae2491512fece11f3c1c963dafedaa5a7b3d4
+ size 48404804
db/2d/2d-how-many/llava15_7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2f1afdc0e89ee8acbcf31992b66c746b1e56e0f6b6b295f78a63a3e7f624b33e
+ size 48404804
db/2d/2d-how-many/merged_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-how-many/path.json ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-how-many/qa.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e82b51a17f31e4f053b7c6cdc22a3f9dd437dc22436757a4e553ba8506f9cf22
+ size 901939
db/2d/2d-how-many/qwenvl_chat_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75307671d6e0c9f8be5375113acac6b53dcaae4a9e5fd3b0ce2ba00fc76c30da
+ size 48404804
db/2d/2d-how-many/qwenvl_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0035392b47d7158672b56f54b13d4f738c161cc8939e88b242e909458c6d3ad2
+ size 48404804
db/2d/2d-how-many/task_plan.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:deaa99058099c23f6011e923e8abcd4e16b0f941a7172462c606b287a599d0a0
+ size 861565
db/2d/2d-what-attribute/embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6956526b21a3ee60a05e86f0c2bad028644d03d42f0216a538ce97016f1b699c
+ size 39137443
db/2d/2d-what-attribute/expanded_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-what-attribute/gt.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b367e0455f641e152fe469a0a7832d8f626a4b9657052342d494ad7a180ff42
+ size 612316
db/2d/2d-what-attribute/instructblip_vicuna13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a20ad832ca2a26f5efe5328c1f3e9a88fda8e55b5e26cad3817e1223ec889af
+ size 35774420
db/2d/2d-what-attribute/instructblip_vicuna7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d0632f15894b374a7e67a8bb6f53fdcf84d344cc310e097adc1f724364eaf2e7
+ size 35774420
db/2d/2d-what-attribute/llava15_13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99c5a95c3e2de9bd0371060a6886f5bf1b4553120dccdcf38f790116e0712b76
+ size 35774420
db/2d/2d-what-attribute/llava15_7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e0ee4d8a941b1f3176e74bea7e158141c2102f4403acf6b7aa3723b7f2e462a4
+ size 35774420
db/2d/2d-what-attribute/merged_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-what-attribute/path.json ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-what-attribute/qa.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e5c3b625e90eea95aac69a176772422c1fb0b882cdee838a7e6c56fbe17b50a
+ size 980447
db/2d/2d-what-attribute/qwenvl_chat_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac6c585bebeb3aaaed8dac2c7c4420d7d2e780a735f57b3e1e5b617f87ea8160
+ size 35774420
db/2d/2d-what-attribute/qwenvl_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c3d7d5aec0726c1f51b7c65ce5c21e9beb9fda2598efe29e96c9788a8b60bdd
+ size 35774420
db/2d/2d-what-attribute/task_plan.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e56dc9b74b81ec36a32811ddc3b6a85fbcf23992185f801d52fbce40516974cd
+ size 783662
db/2d/2d-what/embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b089c1028a5d2f63609739d53e1c7630f66e4cba66d39a93bdf427565ef29104
+ size 39137443
db/2d/2d-what/expanded_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-what/gt.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f7fffa0ff67eab9da2ea4dae3a005e1f9b75c10d12e6ca376dd5072297edf2d5
+ size 612316
db/2d/2d-what/instructblip_vicuna13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e224cfb927268e4a43d289bc5b2025d1b2c767bfca14e284c32b1612a3e4d452
+ size 35774420
db/2d/2d-what/instructblip_vicuna7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fef83cca3ed3fa7fce662b37e6085bb1f3373c91e4306eb2a79e79e807854f7
+ size 35774420
db/2d/2d-what/llava15_13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cea6a0cc75170114f6f2725fe71c7fd353038b22d144a1afa2545579cdb5894
+ size 35774420
db/2d/2d-what/llava15_7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1ec12e11122cc629c502de7b8981a482692d20779704659feda4250f01bd5e9d
+ size 35774420
db/2d/2d-what/merged_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-what/path.json ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-what/qa.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:907fb5a1fe2442e5e4c5837020633d970911f187043e0fc481cf9c06940ab643
+ size 807018
db/2d/2d-what/qwenvl_chat_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ad629744d9f58f163e74cf6e7126617c331681aae15700f5279df8949155b881
+ size 35774420
db/2d/2d-what/qwenvl_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7bf48191df156d0618c524296077992fd75417fa2d7b03dbcc1628319d3989a5
+ size 35774420
db/2d/2d-what/task_plan.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3113b3b2adbb2bb8088ca4ae86d654da511f783ddbf1ce98fff86cd864336225
+ size 783652
db/2d/2d-where-attribute/embeddings.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b8b38f46605548c2fc198a3b3d751a999aa28495a95a52e22a29b859f6a820e7
+ size 39137443
db/2d/2d-where-attribute/expanded_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-where-attribute/gt.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22fe402eebadfb265f0186a2aee934da6d8144daa10a52173cf583ff4b0acec5
+ size 612316
db/2d/2d-where-attribute/instructblip_vicuna13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48ad5886d1c0f90903c8857883b4b589f147f89e1aaa2bb68cff3e1d5707d8ce
+ size 35774420
db/2d/2d-where-attribute/instructblip_vicuna7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:855e2037394af1adfe6d30906aeb038a2672d3d1ee8773c98844a63bc6aa66f1
+ size 35774420
db/2d/2d-where-attribute/llava15_13b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a438d2b5617d58e08e9321b6a9c0907c08fdd2165103617331ebb3d288410d7
+ size 35774420
db/2d/2d-where-attribute/llava15_7b_surprise.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93f420d01d0c87147c218eb76277c78d220acab084cf91443704cd5fb25daaf3
+ size 35774420
db/2d/2d-where-attribute/merged_data.csv ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-where-attribute/path.json ADDED
The diff for this file is too large to render. See raw diff
db/2d/2d-where-attribute/qa.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e45587db39e2316974b4119d2825a27dc8284b46e1d7da643241c5c100ce879
+ size 601802