# Spaces: Sleeping  (web-page status banner captured along with the file; not part of the script)
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from utils.paths import LOGS_PATH, DATA_VIS_PATH, DATA_PATH
def get_args(argv=None):
    """Parse the command-line arguments of the script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``get_args()`` call has always done.

    Returns:
        argparse.Namespace with one string attribute, ``extension_name``
        (the positional "Experiment extension name" argument).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "extension_name",
        type=str,
        help="Experiment extension name",
    )
    return parser.parse_args(argv)
# One entry per plotted metric:
# (output column prefix, CSV column name / legend label, line style, band colour).
METRICS = (
    ("f1score", "F1Score", "-", "#fccfcf"),
    ("precision", "Precision", "--", "#cfeffc"),
    ("recall", "Recall", ":", "#d6ffd1"),
)


def load_threshold_metrics(csv_dir, csv_prefix, thresholds):
    """Collect the mean/std of every metric from the per-threshold CSV files.

    For each threshold the file ``<csv_dir>/<csv_prefix>_<threshold>.csv`` is
    read. Its second-to-last row is taken as the mean and its last row as the
    standard deviation; that layout is assumed, not verified, here.

    Args:
        csv_dir: Directory containing the per-threshold CSV files.
        csv_prefix: File-name prefix shared by those CSV files.
        thresholds: Sequence of threshold values; each one is formatted into
            the file name exactly as ``str`` would print it.

    Returns:
        pandas.DataFrame with the columns ``threshold`` and, for each metric,
        ``<metric>_mean`` and ``<metric>_std`` (f1score, precision, recall).

    Raises:
        ValueError: If a CSV file has fewer than two rows.
        KeyError: If a CSV file lacks one of the metric columns.
    """
    summary = {"threshold": thresholds}
    for prefix, _, _, _ in METRICS:
        summary[f"{prefix}_mean"] = []
        summary[f"{prefix}_std"] = []

    for threshold in thresholds:
        # The threshold is formatted straight into the name instead of going
        # through str.format on a pre-built pattern, so braces in the
        # directory or prefix cannot be mistaken for replacement fields.
        csv_filename = os.path.join(csv_dir, f"{csv_prefix}_{threshold}.csv")
        perf_df = pd.read_csv(csv_filename)
        if len(perf_df) < 2:
            raise ValueError(
                f"{csv_filename} must contain at least two rows "
                f"(mean and std), found {len(perf_df)}"
            )
        mean_row, std_row = perf_df.iloc[-2], perf_df.iloc[-1]
        for prefix, column, _, _ in METRICS:
            summary[f"{prefix}_mean"].append(mean_row[column])
            summary[f"{prefix}_std"].append(std_row[column])

    return pd.DataFrame(summary)


def plot_threshold_metrics(summary, plot_filename):
    """Plot each metric's mean curve with a +/- one std band and save it.

    Args:
        summary: DataFrame as returned by ``load_threshold_metrics``.
        plot_filename: Path of the image file to write.
    """
    thresholds = summary["threshold"].to_numpy()
    fig = plt.figure()
    for prefix, label, linestyle, band_color in METRICS:
        mean = summary[f"{prefix}_mean"].to_numpy()
        std = summary[f"{prefix}_std"].to_numpy()
        high, low = mean + std, mean - std
        plt.plot(thresholds, mean, color="k", linestyle=linestyle, label=label)
        # `where` drops the band wherever the comparison is false, i.e. for a
        # negative or NaN spread.
        plt.fill_between(
            thresholds, high, low, where=high >= low,
            facecolor=band_color, interpolate=True, alpha=0.5,
        )
    plt.xlabel("Threshold")
    plt.xticks(thresholds[1::2])  # label every other threshold
    plt.yticks(np.arange(0.1, 1, 0.1))
    plt.ylim(0, 1)
    plt.grid(alpha=0.3)
    plt.legend()
    plt.savefig(plot_filename, bbox_inches="tight", pad_inches=0.0)
    # Release the figure once it is on disk.
    plt.close(fig)


def main():
    """Summarise per-threshold evaluation CSVs into one CSV and one plot."""
    args = get_args()
    extension_name = args.extension_name

    # Written out literally rather than generated with np.arange: every value
    # is formatted into a file name, and arange's rounding error
    # (e.g. 0.15000000000000002) would change that text.
    thresholds = np.array([0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
                           0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95])

    evaluation_name = f"dl_evaluation_{extension_name}"
    summary = load_threshold_metrics(
        os.path.join(LOGS_PATH, evaluation_name), evaluation_name, thresholds
    )

    csv_filename = os.path.join(DATA_PATH, f"performance_threshold_{extension_name}.csv")
    summary.to_csv(csv_filename, index=False)

    plot_filename = os.path.join(DATA_VIS_PATH, f"performance_threshold_{extension_name}.png")
    plot_threshold_metrics(summary, plot_filename)


if __name__ == "__main__":
    main()