atom-detection / utils /cf_matrix.py
Romain Graux
Initial commit with ml code and webapp
b2ffc9b
raw
history blame contribute delete
No virus
4.21 kB
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
def make_confusion_matrix(cf,
group_names=None,
categories='auto',
count=True,
percent=True,
cbar=True,
cbar_range=(None, None),
xyticks=True,
xyplotlabels=True,
sum_stats=True,
figsize=None,
cmap='Blues',
title=None):
'''
This function will make a pretty plot of an sklearn Confusion Matrix cm using a Seaborn heatmap visualization.
Arguments
---------
cf: confusion matrix to be passed in
group_names: List of strings that represent the labels row by row to be shown in each square.
categories: List of strings containing the categories to be displayed on the x,y axis. Default is 'auto'
count: If True, show the raw number in the confusion matrix. Default is True.
normalize: If True, show the proportions for each category. Default is True.
cbar: If True, show the color bar. The cbar values are based off the values in the confusion matrix.
Default is True.
xyticks: If True, show x and y ticks. Default is True.
xyplotlabels: If True, show 'True Label' and 'Predicted Label' on the figure. Default is True.
sum_stats: If True, display summary statistics below the figure. Default is True.
figsize: Tuple representing the figure size. Default will be the matplotlib rcParams value.
cmap: Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues'
See http://matplotlib.org/examples/color/colormaps_reference.html
title: Title for the heatmap. Default is None.
'''
# CODE TO GENERATE TEXT INSIDE EACH SQUARE
blanks = ['' for i in range(cf.size)]
if group_names and len(group_names) == cf.size:
group_labels = ["{}\n".format(value) for value in group_names]
else:
group_labels = blanks
if count:
group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
else:
group_counts = blanks
if percent:
group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / np.sum(cf)]
else:
group_percentages = blanks
box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])
# CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
if sum_stats:
# Accuracy is sum of diagonal divided by total observations
accuracy = np.trace(cf) / float(np.sum(cf))
# if it is a binary confusion matrix, show some more stats
if len(cf) == 2:
# Metrics for Binary Confusion Matrices
precision = cf[1, 1] / sum(cf[:, 1])
recall = cf[1, 1] / sum(cf[1, :])
f1_score = 2 * precision * recall / (precision + recall)
stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
accuracy, precision, recall, f1_score)
else:
stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)
else:
stats_text = ""
# SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
if figsize == None:
# Get default figure size if not set
figsize = plt.rcParams.get('figure.figsize')
if xyticks == False:
# Do not show categories if xyticks is False
categories = False
# MAKE THE HEATMAP VISUALIZATION
plt.figure(figsize=figsize)
sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, vmin=cbar_range[0], vmax=cbar_range[1], xticklabels=categories, yticklabels=categories)
if xyplotlabels:
plt.ylabel('True label')
plt.xlabel('Predicted label' + stats_text)
else:
plt.xlabel(stats_text)
if title:
plt.title(title)