import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if the denominator is zero."""
    return float(numerator) / float(denominator) if denominator else 0.0


def make_confusion_matrix(cf,
                          group_names=None,
                          categories='auto',
                          count=True,
                          percent=True,
                          cbar=True,
                          cbar_range=(None, None),
                          xyticks=True,
                          xyplotlabels=True,
                          sum_stats=True,
                          figsize=None,
                          cmap='Blues',
                          title=None):
    """
    Plot a confusion matrix as an annotated Seaborn heatmap.

    A new matplotlib figure is created and the heatmap is drawn on it; the
    function does not return anything and does not call ``plt.show()``.

    Arguments
    ---------
    cf:            Confusion matrix to plot. Anything ``np.asarray`` can turn
                   into a 2-D array is accepted. The binary summary metrics
                   read rows as true labels and columns as predicted labels.

    group_names:   Sequence of strings, listed row by row, shown as the first
                   line of each square. Ignored unless its length equals the
                   number of cells in `cf`. Default is None.

    categories:    List of strings containing the categories to be displayed
                   on the x,y axis. Default is 'auto'.

    count:         If True, show the raw number in each square.
                   Default is True.

    percent:       If True, show each cell as a percentage of the sum of all
                   cells. Default is True.

    cbar:          If True, show the color bar. Default is True.

    cbar_range:    ``(vmin, vmax)`` pair passed to the heatmap to anchor the
                   color scale. Default is (None, None).

    xyticks:       If True, show x and y ticks. If False, `categories` is
                   ignored and no tick labels are drawn. Default is True.

    xyplotlabels:  If True, show 'True label' and 'Predicted label' on the
                   figure. Default is True.

    sum_stats:     If True, display summary statistics below the figure:
                   accuracy, plus precision, recall and F1 score for a 2x2
                   matrix. A ratio whose denominator is zero is reported
                   as 0. Default is True.

    figsize:       Tuple representing the figure size. Default is the
                   matplotlib rcParams 'figure.figsize' value.

    cmap:          Colormap used for the heatmap values (any matplotlib
                   colormap name). Default is 'Blues'.

    title:         Title for the heatmap. Default is None.
    """
    cf = np.asarray(cf)
    flat = cf.flatten()
    total = cf.sum()

    # CODE TO GENERATE TEXT INSIDE EACH SQUARE
    blanks = [''] * cf.size

    # `is not None` rather than truthiness so an ndarray of names is accepted.
    if group_names is not None and len(group_names) == cf.size:
        group_labels = [f"{name}\n" for name in group_names]
    else:
        group_labels = blanks

    group_counts = [f"{value:0.0f}\n" for value in flat] if count else blanks

    if percent:
        group_percentages = [f"{_safe_ratio(value, total):.2%}" for value in flat]
    else:
        group_percentages = blanks

    # strip() drops the trailing newline left behind when later parts are blank.
    box_labels = [
        f"{label}{number}{share}".strip()
        for label, number, share in zip(group_labels, group_counts, group_percentages)
    ]
    box_labels = np.asarray(box_labels).reshape(cf.shape)

    # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
    if sum_stats:
        # Accuracy is sum of diagonal divided by total observations
        accuracy = _safe_ratio(np.trace(cf), total)

        # Binary-only metrics need a genuine 2x2 matrix, not merely two rows.
        if cf.shape == (2, 2):
            precision = _safe_ratio(cf[1, 1], cf[:, 1].sum())
            recall = _safe_ratio(cf[1, 1], cf[1, :].sum())
            f1_score = _safe_ratio(2 * precision * recall, precision + recall)
            stats_text = (
                f"\n\nAccuracy={accuracy:0.3f}\nPrecision={precision:0.3f}"
                f"\nRecall={recall:0.3f}\nF1 Score={f1_score:0.3f}"
            )
        else:
            stats_text = f"\n\nAccuracy={accuracy:0.3f}"
    else:
        stats_text = ""

    # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
    if figsize is None:
        # Get default figure size if not set
        figsize = plt.rcParams.get('figure.figsize')

    if not xyticks:
        # Do not show categories if xyticks is False
        categories = False

    # MAKE THE HEATMAP VISUALIZATION
    plt.figure(figsize=figsize)
    sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar,
                vmin=cbar_range[0], vmax=cbar_range[1],
                xticklabels=categories, yticklabels=categories)

    # The summary statistics ride along in the x-axis label so they render
    # underneath the plot.
    if xyplotlabels:
        plt.ylabel('True label')
        plt.xlabel('Predicted label' + stats_text)
    else:
        plt.xlabel(stats_text)

    if title:
        plt.title(title)