Update my_model/results/demo.py
Browse files- my_model/results/demo.py +8 -8
my_model/results/demo.py
CHANGED
@@ -96,7 +96,7 @@ class ResultDemonstrator:
|
|
96 |
|
97 |
def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
|
98 |
"""
|
99 |
-
Plots an interactive scatter plot comparing token
|
100 |
|
101 |
Args:
|
102 |
conf (str): The configuration name.
|
@@ -125,13 +125,13 @@ class ResultDemonstrator:
|
|
125 |
legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
|
126 |
color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
|
127 |
|
128 |
-
# Retrieve token
|
129 |
-
|
130 |
|
131 |
# Create a DataFrame for the scatter plot
|
132 |
scatter_data = pd.DataFrame({
|
133 |
-
'Index': range(len(
|
134 |
-
'Token Count':
|
135 |
score_name: legend_map
|
136 |
})
|
137 |
|
@@ -143,14 +143,14 @@ class ResultDemonstrator:
|
|
143 |
stroke='black' # Sets the border color to black
|
144 |
).encode(
|
145 |
x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
|
146 |
-
y=alt.Y('Token
|
147 |
color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
|
148 |
-
tooltip=['Index', 'Token
|
149 |
).interactive() # Enables zoom & pan
|
150 |
|
151 |
chart = chart.properties(
|
152 |
title={
|
153 |
-
"text": f"Token
|
154 |
"color": "black", # Optional color
|
155 |
"fontSize": 20, # Optional font size
|
156 |
"anchor": "middle", # Optional anchor position
|
|
|
96 |
|
97 |
def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
|
98 |
"""
|
99 |
+
Plots an interactive scatter plot comparing token count to VQA or EM scores using Altair.
|
100 |
|
101 |
Args:
|
102 |
conf (str): The configuration name.
|
|
|
125 |
legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
|
126 |
color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
|
127 |
|
128 |
+
# Retrieve token count from the data
|
129 |
+
token_count = self.main_data[f'tokens_count_{conf}']
|
130 |
|
131 |
# Create a DataFrame for the scatter plot
|
132 |
scatter_data = pd.DataFrame({
|
133 |
+
'Index': range(len(token_count)),
|
134 |
+
'Token Count': token_count,
|
135 |
score_name: legend_map
|
136 |
})
|
137 |
|
|
|
143 |
stroke='black' # Sets the border color to black
|
144 |
).encode(
|
145 |
x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
|
146 |
+
y=alt.Y('Token Count', scale=alt.Scale(domain=[token_count.min()-200, token_count.max()+200])),
|
147 |
color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
|
148 |
+
tooltip=['Index', 'Token Count', score_name]
|
149 |
).interactive() # Enables zoom & pan
|
150 |
|
151 |
chart = chart.properties(
|
152 |
title={
|
153 |
+
"text": f"Token Count vs {score_name} ({model_configuration})",
|
154 |
"color": "black", # Optional color
|
155 |
"fontSize": 20, # Optional font size
|
156 |
"anchor": "middle", # Optional anchor position
|